mirror of https://github.com/langgenius/dify.git
feat(trigger): enhance trigger subscription management and processing
- Refactor trigger provider classes to improve naming consistency and clarity - Introduce new methods for managing trigger subscriptions, including validation and dispatching - Update API endpoints to reflect changes in subscription handling - Implement logging and request management for endpoint interactions - Enhance data models to support subscription attributes and lifecycle management Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
6acc77d86d
commit
2f08306695
|
|
@ -19,10 +19,6 @@ from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
|
|
@ -87,7 +83,7 @@ class WebhookTriggerApi(Resource):
|
|||
base_url = dify_config.SERVICE_API_URL
|
||||
webhook_trigger.webhook_url = f"{base_url}/triggers/webhook/{webhook_trigger.webhook_id}"
|
||||
webhook_trigger.webhook_debug_url = f"{base_url}/triggers/webhook-debug/{webhook_trigger.webhook_id}"
|
||||
|
||||
|
||||
return webhook_trigger
|
||||
|
||||
@setup_required
|
||||
|
|
@ -231,7 +227,7 @@ class AppTriggerEnableApi(Resource):
|
|||
trigger.icon = url_prefix + trigger.provider_name + "/icon"
|
||||
else:
|
||||
trigger.icon = ""
|
||||
|
||||
|
||||
return trigger
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class TriggerProviderListApi(Resource):
|
|||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
|
||||
|
||||
class TriggerProviderSubscriptionListApi(Resource):
|
||||
class TriggerSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -54,7 +54,7 @@ class TriggerProviderSubscriptionListApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
class TriggerProviderSubscriptionsAddApi(Resource):
|
||||
class TriggerSubscriptionsAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -99,7 +99,7 @@ class TriggerProviderSubscriptionsAddApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
class TriggerProviderSubscriptionsDeleteApi(Resource):
|
||||
class TriggerSubscriptionsDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -125,7 +125,7 @@ class TriggerProviderSubscriptionsDeleteApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
class TriggerProviderOAuthAuthorizeApi(Resource):
|
||||
class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -189,7 +189,7 @@ class TriggerProviderOAuthAuthorizeApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
class TriggerProviderOAuthCallbackApi(Resource):
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
|
|
@ -252,7 +252,7 @@ class TriggerProviderOAuthCallbackApi(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class TriggerProviderOAuthRefreshTokenApi(Resource):
|
||||
class TriggerOAuthRefreshTokenApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -278,7 +278,7 @@ class TriggerProviderOAuthRefreshTokenApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
class TriggerProviderOAuthClientManageApi(Resource):
|
||||
class TriggerOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -381,25 +381,25 @@ class TriggerProviderOAuthClientManageApi(Resource):
|
|||
# Trigger provider endpoints
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-providers")
|
||||
api.add_resource(
|
||||
TriggerProviderSubscriptionListApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/list"
|
||||
TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/list"
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerProviderSubscriptionsAddApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/add"
|
||||
TriggerSubscriptionsAddApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/add"
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerProviderSubscriptionsDeleteApi,
|
||||
TriggerSubscriptionsDeleteApi,
|
||||
"/workspaces/current/trigger-provider/subscriptions/<path:subscription_id>/delete",
|
||||
)
|
||||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerProviderOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(
|
||||
TriggerProviderOAuthRefreshTokenApi,
|
||||
TriggerOAuthRefreshTokenApi,
|
||||
"/workspaces/current/trigger-provider/subscriptions/<path:subscription_id>/oauth/refresh",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerProviderOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
|
||||
TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,21 +1,45 @@
|
|||
import logging
|
||||
import re
|
||||
|
||||
from flask import jsonify, request
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_subscription_validation_service import TriggerSubscriptionValidationService
|
||||
from services.trigger_service import TriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
|
||||
UUID_MATCHER = re.compile(UUID_PATTERN)
|
||||
|
||||
@bp.route("/trigger/webhook/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def trigger_webhook(endpoint_id: str):
|
||||
|
||||
@bp.route(
|
||||
"/trigger/endpoint/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
|
||||
)
|
||||
@bp.route(
|
||||
"/trigger/endpoint-debug/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
|
||||
)
|
||||
def trigger_endpoint(endpoint_id: str):
|
||||
"""
|
||||
Handle webhook trigger calls.
|
||||
Handle endpoint trigger calls.
|
||||
"""
|
||||
# endpoint_id must be UUID
|
||||
if not UUID_MATCHER.match(endpoint_id):
|
||||
raise NotFound("Invalid endpoint ID")
|
||||
handling_chain = [
|
||||
TriggerService.process_endpoint,
|
||||
TriggerSubscriptionValidationService.process_validating_endpoint,
|
||||
]
|
||||
try:
|
||||
return TriggerService.process_webhook(endpoint_id, request)
|
||||
for handler in handling_chain:
|
||||
response = handler(endpoint_id, request)
|
||||
if response:
|
||||
break
|
||||
if not response:
|
||||
raise NotFound("Endpoint not found")
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Any, Literal, Optional
|
||||
|
||||
from flask import Response
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
|
|
@ -237,3 +238,24 @@ class RequestFetchAppInfo(BaseModel):
|
|||
"""
|
||||
|
||||
app_id: str
|
||||
|
||||
|
||||
class TriggerInvokeResponse(BaseModel):
|
||||
event: dict[str, Any]
|
||||
|
||||
|
||||
class PluginTriggerDispatchResponse(BaseModel):
|
||||
triggers: list[str]
|
||||
raw_http_response: str
|
||||
|
||||
class TriggerSubscriptionResponse(BaseModel):
|
||||
subscription: dict[str, Any]
|
||||
|
||||
class TriggerValidateProviderCredentialsResponse(BaseModel):
|
||||
valid: bool
|
||||
message: str
|
||||
error: str
|
||||
|
||||
class TriggerDispatchResponse:
|
||||
triggers: list[str]
|
||||
response: Response
|
||||
|
|
|
|||
|
|
@ -1,14 +1,26 @@
|
|||
import binascii
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginTriggerProviderEntity
|
||||
from flask import Request
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginTriggerProviderEntity
|
||||
from core.plugin.entities.request import (
|
||||
PluginTriggerDispatchResponse,
|
||||
TriggerDispatchResponse,
|
||||
TriggerInvokeResponse,
|
||||
TriggerSubscriptionResponse,
|
||||
TriggerValidateProviderCredentialsResponse,
|
||||
)
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.http_parser import deserialize_response, serialize_request
|
||||
from core.trigger.entities.entities import Subscription
|
||||
|
||||
|
||||
class PluginTriggerManager(BasePluginClient):
|
||||
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
|
||||
"""
|
||||
Fetch tool providers for the given tenant.
|
||||
Fetch trigger providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
|
|
@ -31,7 +43,7 @@ class PluginTriggerManager(BasePluginClient):
|
|||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for trigger in provider.declaration.triggers:
|
||||
trigger.identity.provider = provider.declaration.identity.name
|
||||
|
||||
|
|
@ -39,7 +51,7 @@ class PluginTriggerManager(BasePluginClient):
|
|||
|
||||
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
Fetch trigger provider for the given tenant and plugin.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
|
|
@ -65,3 +77,224 @@ class PluginTriggerManager(BasePluginClient):
|
|||
trigger.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def invoke_trigger(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
trigger: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
parameters: dict[str, Any],
|
||||
) -> TriggerInvokeResponse:
|
||||
"""
|
||||
Invoke a trigger with the given parameters.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/invoke",
|
||||
TriggerInvokeResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"trigger": trigger,
|
||||
"credentials": credentials,
|
||||
"credential_type": credential_type,
|
||||
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||
"parameters": parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return TriggerInvokeResponse(event=resp.event)
|
||||
|
||||
raise ValueError("No response received from plugin daemon for invoke trigger")
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
) -> TriggerValidateProviderCredentialsResponse:
|
||||
"""
|
||||
Validate the credentials of the trigger provider.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/validate_credentials",
|
||||
TriggerValidateProviderCredentialsResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
return TriggerValidateProviderCredentialsResponse(valid=False, message="No response", error="No response")
|
||||
|
||||
def dispatch_event(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: dict[str, Any],
|
||||
request: Request,
|
||||
) -> TriggerDispatchResponse:
|
||||
"""
|
||||
Dispatch an event to triggers.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/dispatch_event",
|
||||
PluginTriggerDispatchResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"subscription": subscription,
|
||||
"raw_http_request": binascii.hexlify(serialize_request(request)).decode(),
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return TriggerDispatchResponse(
|
||||
triggers=resp.triggers,
|
||||
response=deserialize_response(binascii.unhexlify(resp.raw_http_response.encode())),
|
||||
)
|
||||
|
||||
raise ValueError("No response received from plugin daemon for dispatch event")
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
credentials: dict[str, Any],
|
||||
endpoint: str,
|
||||
parameters: dict[str, Any],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Subscribe to a trigger.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/subscribe",
|
||||
TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
"endpoint": endpoint,
|
||||
"parameters": parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for subscribe")
|
||||
|
||||
def unsubscribe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: dict[str, Any],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Unsubscribe from a trigger.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/unsubscribe",
|
||||
TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"subscription": subscription.model_dump(),
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for unsubscribe")
|
||||
|
||||
def refresh(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: dict[str, Any],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Refresh a trigger subscription.
|
||||
"""
|
||||
trigger_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/trigger/refresh",
|
||||
TriggerSubscriptionResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": trigger_provider_id.provider_name,
|
||||
"subscription": subscription.model_dump(),
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": trigger_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("No response received from plugin daemon for refresh")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,182 @@
|
|||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from flask import Request, Response
|
||||
from werkzeug.datastructures import Headers
|
||||
|
||||
|
||||
def serialize_request(request: Request) -> bytes:
|
||||
"""
|
||||
Convert a Request object to raw HTTP data.
|
||||
|
||||
Args:
|
||||
request: The Request object to convert.
|
||||
|
||||
Returns:
|
||||
The raw HTTP data as bytes.
|
||||
"""
|
||||
# Start with the request line
|
||||
method = request.method
|
||||
path = request.full_path
|
||||
# Remove trailing ? if there's no query string
|
||||
path = path.removesuffix("?")
|
||||
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
|
||||
raw_data = f"{method} {path} {protocol}\r\n".encode()
|
||||
|
||||
# Add headers
|
||||
for header_name, header_value in request.headers.items():
|
||||
raw_data += f"{header_name}: {header_value}\r\n".encode()
|
||||
|
||||
# Add empty line to separate headers from body
|
||||
raw_data += b"\r\n"
|
||||
|
||||
# Add body if exists
|
||||
body = request.get_data(as_text=False)
|
||||
if body:
|
||||
raw_data += body
|
||||
|
||||
return raw_data
|
||||
|
||||
|
||||
def deserialize_request(raw_data: bytes) -> Request:
|
||||
"""
|
||||
Convert raw HTTP data to a Request object.
|
||||
|
||||
Args:
|
||||
raw_data: The raw HTTP data as bytes.
|
||||
|
||||
Returns:
|
||||
A Flask Request object.
|
||||
"""
|
||||
lines = raw_data.split(b"\r\n")
|
||||
|
||||
# Parse request line
|
||||
request_line = lines[0].decode("utf-8")
|
||||
parts = request_line.split(" ", 2) # Split into max 3 parts
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid request line: {request_line}")
|
||||
|
||||
method = parts[0]
|
||||
path = parts[1]
|
||||
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
|
||||
|
||||
# Parse headers
|
||||
headers = Headers()
|
||||
body_start = 0
|
||||
for i in range(1, len(lines)):
|
||||
line = lines[i]
|
||||
if line == b"":
|
||||
body_start = i + 1
|
||||
break
|
||||
if b":" in line:
|
||||
header_line = line.decode("utf-8")
|
||||
name, value = header_line.split(":", 1)
|
||||
headers[name.strip()] = value.strip()
|
||||
|
||||
# Extract body
|
||||
body = b""
|
||||
if body_start > 0 and body_start < len(lines):
|
||||
body = b"\r\n".join(lines[body_start:])
|
||||
|
||||
# Create environ for Request
|
||||
environ = {
|
||||
"REQUEST_METHOD": method,
|
||||
"PATH_INFO": path.split("?")[0] if "?" in path else path,
|
||||
"QUERY_STRING": path.split("?")[1] if "?" in path else "",
|
||||
"SERVER_NAME": headers.get("Host", "localhost").split(":")[0],
|
||||
"SERVER_PORT": headers.get("Host", "localhost:80").split(":")[1] if ":" in headers.get("Host", "") else "80",
|
||||
"SERVER_PROTOCOL": protocol,
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "https" if headers.get("X-Forwarded-Proto") == "https" else "http",
|
||||
"CONTENT_LENGTH": str(len(body)) if body else "0",
|
||||
"CONTENT_TYPE": headers.get("Content-Type", ""),
|
||||
}
|
||||
|
||||
# Add headers to environ
|
||||
for header_name, header_value in headers.items():
|
||||
env_name = f"HTTP_{header_name.upper().replace('-', '_')}"
|
||||
if header_name.upper() not in ["CONTENT-TYPE", "CONTENT-LENGTH"]:
|
||||
environ[env_name] = header_value
|
||||
|
||||
return Request(environ)
|
||||
|
||||
|
||||
def serialize_response(response: Response) -> bytes:
|
||||
"""
|
||||
Convert a Response object to raw HTTP data.
|
||||
|
||||
Args:
|
||||
response: The Response object to convert.
|
||||
|
||||
Returns:
|
||||
The raw HTTP data as bytes.
|
||||
"""
|
||||
# Start with the status line
|
||||
protocol = "HTTP/1.1"
|
||||
status_code = response.status_code
|
||||
status_text = response.status.split(" ", 1)[1] if " " in response.status else "OK"
|
||||
raw_data = f"{protocol} {status_code} {status_text}\r\n".encode()
|
||||
|
||||
# Add headers
|
||||
for header_name, header_value in response.headers.items():
|
||||
raw_data += f"{header_name}: {header_value}\r\n".encode()
|
||||
|
||||
# Add empty line to separate headers from body
|
||||
raw_data += b"\r\n"
|
||||
|
||||
# Add body if exists
|
||||
body = response.get_data(as_text=False)
|
||||
if body:
|
||||
raw_data += body
|
||||
|
||||
return raw_data
|
||||
|
||||
|
||||
def deserialize_response(raw_data: bytes) -> Response:
|
||||
"""
|
||||
Convert raw HTTP data to a Response object.
|
||||
|
||||
Args:
|
||||
raw_data: The raw HTTP data as bytes.
|
||||
|
||||
Returns:
|
||||
A Flask Response object.
|
||||
"""
|
||||
lines = raw_data.split(b"\r\n")
|
||||
|
||||
# Parse status line
|
||||
status_line = lines[0].decode("utf-8")
|
||||
parts = status_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid status line: {status_line}")
|
||||
|
||||
protocol = parts[0]
|
||||
status_code = int(parts[1])
|
||||
status_text = parts[2] if len(parts) > 2 else "OK"
|
||||
|
||||
# Parse headers
|
||||
headers: dict[str, Any] = {}
|
||||
body_start = 0
|
||||
for i in range(1, len(lines)):
|
||||
line = lines[i]
|
||||
if line == b"":
|
||||
body_start = i + 1
|
||||
break
|
||||
if b":" in line:
|
||||
header_line = line.decode("utf-8")
|
||||
name, value = header_line.split(":", 1)
|
||||
headers[name.strip()] = value.strip()
|
||||
|
||||
# Extract body
|
||||
body = b""
|
||||
if body_start > 0 and body_start < len(lines):
|
||||
body = b"\r\n".join(lines[body_start:])
|
||||
|
||||
# Create Response object
|
||||
response = Response(
|
||||
response=body,
|
||||
status=status_code,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
@ -7,6 +7,7 @@ from core.entities.provider_entities import ProviderConfig
|
|||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import (
|
||||
OAuthSchema,
|
||||
Subscription,
|
||||
SubscriptionSchema,
|
||||
TriggerDescription,
|
||||
TriggerEntity,
|
||||
|
|
@ -39,5 +40,27 @@ class TriggerApiEntity(BaseModel):
|
|||
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
|
||||
|
||||
class SubscriptionValidation(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
provider_id: str
|
||||
endpoint: str
|
||||
parameters: dict
|
||||
properties: dict
|
||||
credentials: dict
|
||||
credential_type: str
|
||||
credential_expires_at: int
|
||||
expires_at: int
|
||||
|
||||
def to_subscription(self) -> Subscription:
|
||||
return Subscription(
|
||||
expires_at=self.expires_at,
|
||||
endpoint=self.endpoint,
|
||||
parameters=self.parameters,
|
||||
properties=self.properties,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"]
|
||||
|
|
|
|||
|
|
@ -143,12 +143,19 @@ class Subscription(BaseModel):
|
|||
Contains all information needed to manage the subscription lifecycle.
|
||||
"""
|
||||
|
||||
expire_at: int = Field(
|
||||
expires_at: int = Field(
|
||||
..., description="The timestamp when the subscription will expire, this for refresh the subscription"
|
||||
)
|
||||
|
||||
metadata: dict[str, Any] = Field(
|
||||
..., description="Metadata about the subscription in the external service, defined in subscription_schema"
|
||||
endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events")
|
||||
|
||||
parameters: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="""The parameters of the subscription, this is the creation parameters.
|
||||
Only available when creating a new subscription by credentials(auto subscription), not manual subscription""",
|
||||
)
|
||||
properties: dict[str, Any] = Field(
|
||||
..., description="Subscription data containing all properties and provider-specific information"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,12 +3,19 @@ Trigger Provider Controller for managing trigger providers
|
|||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Request
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import (
|
||||
TriggerDispatchResponse,
|
||||
TriggerInvokeResponse,
|
||||
TriggerValidateProviderCredentialsResponse,
|
||||
)
|
||||
from core.plugin.impl.trigger import PluginTriggerManager
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity
|
||||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
|
|
@ -90,18 +97,36 @@ class PluginTriggerProviderController:
|
|||
|
||||
:return: List of subscription config schemas
|
||||
"""
|
||||
return self.entity.subscription_schema
|
||||
# Return the parameters schema from the subscription schema
|
||||
if self.entity.subscription_schema and self.entity.subscription_schema.parameters_schema:
|
||||
return self.entity.subscription_schema.parameters_schema
|
||||
return []
|
||||
|
||||
def validate_credentials(self, credentials: dict) -> None:
|
||||
def validate_credentials(self, credentials: dict) -> TriggerValidateProviderCredentialsResponse:
|
||||
"""
|
||||
Validate credentials against schema
|
||||
|
||||
:param credentials: Credentials to validate
|
||||
:raises ValueError: If credentials are invalid
|
||||
:return: Validation response
|
||||
"""
|
||||
# First validate against schema
|
||||
for config in self.entity.credentials_schema:
|
||||
if config.required and config.name not in credentials:
|
||||
raise ValueError(f"Missing required credential field: {config.name}")
|
||||
return TriggerValidateProviderCredentialsResponse(
|
||||
valid=False,
|
||||
message=f"Missing required credential field: {config.name}",
|
||||
error=f"Missing required credential field: {config.name}",
|
||||
)
|
||||
|
||||
# Then validate with the plugin daemon
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
return manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="system", # System validation
|
||||
provider=str(provider_id),
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
def get_supported_credential_types(self) -> list[CredentialType]:
|
||||
"""
|
||||
|
|
@ -143,58 +168,127 @@ class PluginTriggerProviderController:
|
|||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""Check if this provider needs credentials"""
|
||||
return len(self.get_supported_credential_types()) > 0
|
||||
def dispatch(self,user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
|
||||
"""
|
||||
Dispatch a trigger through plugin runtime
|
||||
|
||||
def execute_trigger(self, trigger_name: str, parameters: dict, credentials: dict) -> dict:
|
||||
:param user_id: User ID
|
||||
:param request: Flask request object
|
||||
:param subscription: Subscription
|
||||
:return: Dispatch response with triggers and raw HTTP response
|
||||
"""
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
|
||||
response = manager.dispatch_event(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
subscription=subscription.model_dump(),
|
||||
request=request,
|
||||
)
|
||||
return response
|
||||
|
||||
def invoke_trigger(
|
||||
self,
|
||||
user_id: str,
|
||||
trigger_name: str,
|
||||
parameters: dict,
|
||||
credentials: dict,
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
) -> TriggerInvokeResponse:
|
||||
"""
|
||||
Execute a trigger through plugin runtime
|
||||
|
||||
:param user_id: User ID
|
||||
:param trigger_name: Trigger name
|
||||
:param parameters: Trigger parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Execution result
|
||||
:param credential_type: Credential type
|
||||
:param request: Request
|
||||
:return: Trigger execution result
|
||||
"""
|
||||
logger.info("Executing trigger %s for plugin %s", trigger_name, self.plugin_id)
|
||||
return {
|
||||
"success": True,
|
||||
"trigger": trigger_name,
|
||||
"plugin": self.plugin_id,
|
||||
"result": "Trigger executed successfully",
|
||||
}
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
|
||||
def subscribe_trigger(self, trigger_name: str, subscription_params: dict, credentials: dict) -> Subscription:
|
||||
return manager.invoke_trigger(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
trigger=trigger_name,
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
request=request,
|
||||
parameters=parameters,
|
||||
)
|
||||
|
||||
def subscribe_trigger(self, user_id: str, endpoint: str, parameters: dict, credentials: dict) -> Subscription:
|
||||
"""
|
||||
Subscribe to a trigger through plugin runtime
|
||||
|
||||
:param trigger_name: Trigger name
|
||||
:param user_id: User ID
|
||||
:param endpoint: Subscription endpoint
|
||||
:param subscription_params: Subscription parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Subscription result
|
||||
"""
|
||||
logger.info("Subscribing to trigger %s for plugin %s", trigger_name, self.plugin_id)
|
||||
return Subscription(
|
||||
expire_at=int(time.time()) + 86400, # 24 hours from now
|
||||
metadata={
|
||||
"subscription_id": f"{self.plugin_id}_{trigger_name}_{time.time()}",
|
||||
"webhook_url": f"/triggers/webhook/{self.plugin_id}/{trigger_name}",
|
||||
**subscription_params,
|
||||
},
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
|
||||
response = manager.subscribe(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
credentials=credentials,
|
||||
endpoint=endpoint,
|
||||
parameters=parameters,
|
||||
)
|
||||
|
||||
def unsubscribe_trigger(self, trigger_name: str, subscription_metadata: dict, credentials: dict) -> Unsubscription:
|
||||
return Subscription.model_validate(response.subscription)
|
||||
|
||||
def unsubscribe_trigger(self, user_id: str, subscription: Subscription, credentials: dict) -> Unsubscription:
|
||||
"""
|
||||
Unsubscribe from a trigger through plugin runtime
|
||||
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_metadata: Subscription metadata
|
||||
:param user_id: User ID
|
||||
:param subscription: Subscription metadata
|
||||
:param credentials: Provider credentials
|
||||
:return: Unsubscription result
|
||||
"""
|
||||
logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id)
|
||||
return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}")
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
|
||||
response = manager.unsubscribe(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
subscription=subscription,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
return Unsubscription.model_validate(response.subscription)
|
||||
|
||||
def refresh_trigger(self, subscription: Subscription, credentials: dict) -> Subscription:
|
||||
"""
|
||||
Refresh a trigger subscription through plugin runtime
|
||||
|
||||
:param subscription: Subscription metadata
|
||||
:param credentials: Provider credentials
|
||||
:return: Refreshed subscription result
|
||||
"""
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
|
||||
response = manager.refresh(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="system", # System refresh
|
||||
provider=str(provider_id),
|
||||
subscription=subscription,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
return Subscription.model_validate(response.subscription)
|
||||
|
||||
|
||||
__all__ = ["PluginTriggerProviderController"]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,10 @@ Trigger Manager for loading and managing trigger providers and triggers
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask import Request
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.request import TriggerInvokeResponse
|
||||
from core.plugin.impl.trigger import PluginTriggerManager
|
||||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
|
|
@ -14,6 +17,7 @@ from core.trigger.entities.entities import (
|
|||
Unsubscription,
|
||||
)
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -123,75 +127,90 @@ class TriggerManager:
|
|||
:return: Tuple of (is_valid, error_message)
|
||||
"""
|
||||
try:
|
||||
cls.get_trigger_provider(tenant_id, provider_id).validate_credentials(credentials)
|
||||
return True, ""
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
validation_result = provider.validate_credentials(credentials)
|
||||
return validation_result.valid, validation_result.message if not validation_result.valid else ""
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def execute_trigger(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str, parameters: dict, credentials: dict
|
||||
) -> dict:
|
||||
def invoke_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
parameters: dict,
|
||||
credentials: dict,
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
) -> TriggerInvokeResponse:
|
||||
"""
|
||||
Execute a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:param parameters: Trigger parameters
|
||||
:param credentials: Provider credentials
|
||||
:param credential_type: Credential type
|
||||
:param request: Request
|
||||
:return: Trigger execution result
|
||||
"""
|
||||
trigger = cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
trigger = provider.get_trigger(trigger_name)
|
||||
if not trigger:
|
||||
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}")
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).execute_trigger(trigger_name, parameters, credentials)
|
||||
return provider.invoke_trigger(user_id, trigger_name, parameters, credentials, credential_type, request)
|
||||
|
||||
@classmethod
|
||||
def subscribe_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
subscription_params: dict,
|
||||
endpoint: str,
|
||||
parameters: dict,
|
||||
credentials: dict,
|
||||
) -> Subscription:
|
||||
"""
|
||||
Subscribe to a trigger (e.g., register webhook)
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_params: Subscription parameters
|
||||
:param endpoint: Subscription endpoint
|
||||
:param parameters: Subscription parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Subscription result
|
||||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).subscribe_trigger(
|
||||
trigger_name, subscription_params, credentials
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
return provider.subscribe_trigger(
|
||||
user_id=user_id, endpoint=endpoint, parameters=parameters, credentials=credentials
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unsubscribe_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
subscription_metadata: dict,
|
||||
subscription: Subscription,
|
||||
credentials: dict,
|
||||
) -> Unsubscription:
|
||||
"""
|
||||
Unsubscribe from a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_metadata: Subscription metadata from subscribe operation
|
||||
:param subscription: Subscription metadata from subscribe operation
|
||||
:param credentials: Provider credentials
|
||||
:return: Unsubscription result
|
||||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).unsubscribe_trigger(
|
||||
trigger_name, subscription_metadata, credentials
|
||||
)
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)
|
||||
|
||||
@classmethod
|
||||
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[ProviderConfig]:
|
||||
|
|
@ -204,6 +223,27 @@ class TriggerManager:
|
|||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
|
||||
|
||||
@classmethod
|
||||
def refresh_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
subscription: Subscription,
|
||||
credentials: dict,
|
||||
) -> Subscription:
|
||||
"""
|
||||
Refresh a trigger subscription
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription: Subscription metadata from subscribe operation
|
||||
:param credentials: Provider credentials
|
||||
:return: Refreshed subscription result
|
||||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(trigger_name, subscription, credentials)
|
||||
|
||||
|
||||
# Export
|
||||
__all__ = ["TriggerManager"]
|
||||
|
|
|
|||
|
|
@ -23,4 +23,4 @@ webhook_trigger_fields = {
|
|||
"node_id": fields.String,
|
||||
"triggered_by": fields.String,
|
||||
"created_at": fields.DateTime(dt_format="iso8601"),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ from sqlalchemy import DateTime, Index, Integer, String, UniqueConstraint, func
|
|||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
|
||||
from core.trigger.entities.api_entities import SubscriptionValidation, TriggerProviderSubscriptionApiEntity
|
||||
from core.trigger.entities.entities import Subscription
|
||||
from models.base import Base
|
||||
from models.types import StringUUID
|
||||
|
||||
|
|
@ -24,6 +25,7 @@ class TriggerSubscription(Base):
|
|||
sa.PrimaryKeyConstraint("id", name="trigger_subscription_pkey"),
|
||||
Index("idx_trigger_subscriptions_tenant_provider", "tenant_id", "provider_id"),
|
||||
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_subscription"),
|
||||
UniqueConstraint("endpoint", name="unique_trigger_subscription_endpoint"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
|
|
@ -33,8 +35,9 @@ class TriggerSubscription(Base):
|
|||
provider_id: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
endpoint: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
parameters: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||
configuration: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription configuration JSON")
|
||||
properties: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||
|
||||
credentials: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription credentials JSON")
|
||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||
|
|
@ -57,6 +60,14 @@ class TriggerSubscription(Base):
|
|||
# Check if token expires in next 3 minutes
|
||||
return (self.credential_expires_at - 180) < int(time.time())
|
||||
|
||||
def to_entity(self) -> Subscription:
|
||||
return Subscription(
|
||||
expires_at=self.expires_at,
|
||||
endpoint=self.endpoint,
|
||||
parameters=self.parameters,
|
||||
properties=self.properties,
|
||||
)
|
||||
|
||||
def to_api_entity(self) -> TriggerProviderSubscriptionApiEntity:
|
||||
return TriggerProviderSubscriptionApiEntity(
|
||||
id=self.id,
|
||||
|
|
@ -66,7 +77,6 @@ class TriggerSubscription(Base):
|
|||
credentials=self.credentials,
|
||||
)
|
||||
|
||||
|
||||
# system level trigger oauth client params
|
||||
class TriggerOAuthSystemClient(Base):
|
||||
__tablename__ = "trigger_oauth_system_clients"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import re
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Request, Response
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
|
@ -15,7 +16,11 @@ from core.plugin.entities.plugin import TriggerProviderID
|
|||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity
|
||||
from core.trigger.entities.api_entities import (
|
||||
SubscriptionValidation,
|
||||
TriggerProviderApiEntity,
|
||||
TriggerProviderSubscriptionApiEntity,
|
||||
)
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import (
|
||||
create_trigger_provider_encrypter_for_subscription,
|
||||
|
|
@ -32,7 +37,7 @@ logger = logging.getLogger(__name__)
|
|||
class TriggerProviderService:
|
||||
"""Service for managing trigger providers and credentials"""
|
||||
|
||||
__MAX_TRIGGER_PROVIDER_COUNT__ = 100
|
||||
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
|
||||
|
||||
@classmethod
|
||||
def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
|
||||
|
|
@ -553,3 +558,24 @@ class TriggerProviderService:
|
|||
except Exception as e:
|
||||
logger.warning("Error generating provider name")
|
||||
return f"{credential_type.get_name()} 1"
|
||||
|
||||
@classmethod
|
||||
def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
|
||||
"""
|
||||
Get a trigger subscription by the endpoint ID.
|
||||
"""
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first()
|
||||
return subscription
|
||||
|
||||
@classmethod
|
||||
def get_subscription_validation(cls, endpoint_id: str) -> SubscriptionValidation | None:
|
||||
"""
|
||||
Get a trigger subscription by the endpoint ID.
|
||||
"""
|
||||
cache_key = f"trigger:subscription:validation:endpoint:{endpoint_id}"
|
||||
subscription_cache = redis_client.get(cache_key)
|
||||
if subscription_cache:
|
||||
return SubscriptionValidation.model_validate(json.loads(subscription_cache))
|
||||
|
||||
return None
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import logging
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerSubscriptionValidationService:
|
||||
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
|
||||
__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
|
||||
|
||||
@classmethod
|
||||
def append_validation_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
|
||||
"""
|
||||
Append the validation request log to Redis.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def process_validating_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
"""
|
||||
Process a temporary endpoint request.
|
||||
|
||||
:param endpoint_id: The endpoint identifier
|
||||
:param request: The Flask request object
|
||||
:return: The Flask response object
|
||||
"""
|
||||
# check if validation endpoint exists
|
||||
subscription_validation = TriggerProviderService.get_subscription_validation(endpoint_id)
|
||||
if not subscription_validation:
|
||||
return None
|
||||
|
||||
# response to validation endpoint
|
||||
controller = TriggerManager.get_trigger_provider(
|
||||
subscription_validation.tenant_id, TriggerProviderID(subscription_validation.provider_id)
|
||||
)
|
||||
response = controller.dispatch(
|
||||
user_id=subscription_validation.user_id,
|
||||
request=request,
|
||||
subscription=subscription_validation.to_subscription(),
|
||||
)
|
||||
# append the request log
|
||||
cls.append_validation_request_log(endpoint_id, request, response.response)
|
||||
return response.response
|
||||
|
|
@ -1,23 +1,192 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerService:
|
||||
__MAX_REQUEST_LOG_COUNT__ = 10
|
||||
__TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000
|
||||
__ENDPOINT_REQUEST_CACHE_COUNT__ = 10
|
||||
__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
|
||||
# Lua script for atomic write with time & count based cleanup
|
||||
__LUA_SCRIPT__ = """
|
||||
-- KEYS[1] = zset key
|
||||
-- ARGV[1] = max_count (maximum number of entries to keep)
|
||||
-- ARGV[2] = min_ts_ms (minimum timestamp to keep = now_ms - ttl_ms)
|
||||
-- ARGV[3] = now_ms (current timestamp in milliseconds)
|
||||
-- ARGV[4] = member (log entry JSON)
|
||||
|
||||
local key = KEYS[1]
|
||||
local maxCount = tonumber(ARGV[1])
|
||||
local minTs = tonumber(ARGV[2])
|
||||
local nowMs = tonumber(ARGV[3])
|
||||
local member = ARGV[4]
|
||||
|
||||
-- 1) Add new entry with timestamp as score
|
||||
redis.call('ZADD', key, nowMs, member)
|
||||
|
||||
-- 2) Remove entries older than minTs (time-based cleanup)
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', minTs)
|
||||
|
||||
-- 3) Remove oldest entries if count exceeds maxCount (count-based cleanup)
|
||||
local n = redis.call('ZCARD', key)
|
||||
if n > maxCount then
|
||||
redis.call('ZREMRANGEBYRANK', key, 0, n - maxCount - 1) -- 0 is oldest
|
||||
end
|
||||
|
||||
return n
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def process_webhook(cls, webhook_id: str, request: Request) -> Response:
|
||||
"""Extract and process data from incoming webhook request."""
|
||||
# TODO redis slidingwindow log, save the recent request log in redis, rollover the log when the window is full
|
||||
def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
"""Extract and process data from incoming endpoint request."""
|
||||
subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id)
|
||||
if not subscription:
|
||||
return None
|
||||
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
controller = TriggerManager.get_trigger_provider(subscription.tenant_id, provider_id)
|
||||
if not controller:
|
||||
return None
|
||||
|
||||
dispatch_response = controller.dispatch(
|
||||
user_id=subscription.user_id, request=request, subscription=subscription.to_entity()
|
||||
)
|
||||
|
||||
# TODO find the trigger subscription
|
||||
# TODO invoke triggers
|
||||
# dispatch_response.triggers
|
||||
|
||||
# TODO fetch the trigger controller
|
||||
return dispatch_response.response
|
||||
|
||||
# TODO dispatch by the trigger controller
|
||||
@classmethod
|
||||
def log_endpoint_request(cls, endpoint_id: str, request: Request) -> int:
|
||||
"""
|
||||
Log the endpoint request to Redis using ZSET for rolling log with time & count based retention.
|
||||
|
||||
# TODO using the dispatch result(events) to invoke the trigger events
|
||||
raise NotImplementedError("Not implemented")
|
||||
Args:
|
||||
endpoint_id: The endpoint identifier
|
||||
request: The Flask request object
|
||||
|
||||
Returns:
|
||||
The current number of logged requests for this endpoint
|
||||
"""
|
||||
try:
|
||||
# Prepare timestamp
|
||||
now_ms = int(time.time() * 1000)
|
||||
min_ts = now_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__
|
||||
|
||||
# Extract request data
|
||||
request_data = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"timestamp": now_ms,
|
||||
"method": request.method,
|
||||
"path": request.path,
|
||||
"headers": dict(request.headers),
|
||||
"query_params": request.args.to_dict(flat=False) if request.args else {},
|
||||
"body": None,
|
||||
"remote_addr": request.remote_addr,
|
||||
}
|
||||
|
||||
# Try to get request body if it exists
|
||||
if request.is_json:
|
||||
try:
|
||||
request_data["body"] = request.get_json(force=True)
|
||||
except Exception:
|
||||
request_data["body"] = request.get_data(as_text=True)
|
||||
elif request.data:
|
||||
request_data["body"] = request.get_data(as_text=True)
|
||||
|
||||
# Serialize to JSON
|
||||
member = json.dumps(request_data, separators=(",", ":"))
|
||||
|
||||
# Execute Lua script atomically
|
||||
key = f"trigger:endpoint_requests:{endpoint_id}"
|
||||
count = redis_client.eval(
|
||||
cls.__LUA_SCRIPT__,
|
||||
1, # number of keys
|
||||
key, # KEYS[1]
|
||||
str(cls.__ENDPOINT_REQUEST_CACHE_COUNT__), # ARGV[1] - max count
|
||||
str(min_ts), # ARGV[2] - minimum timestamp
|
||||
str(now_ms), # ARGV[3] - current timestamp
|
||||
member, # ARGV[4] - log entry
|
||||
)
|
||||
|
||||
logger.debug("Logged request for endpoint %s, current count: %s", endpoint_id, count)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to log endpoint request for %s", endpoint_id, exc_info=e)
|
||||
# Don't fail the main request processing if logging fails
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def get_recent_endpoint_requests(
|
||||
cls, endpoint_id: str, limit: int = 100, start_time_ms: Optional[int] = None, end_time_ms: Optional[int] = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve recent logged requests for an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_id: The endpoint identifier
|
||||
limit: Maximum number of entries to return
|
||||
start_time_ms: Start timestamp in milliseconds (optional)
|
||||
end_time_ms: End timestamp in milliseconds (optional, defaults to now)
|
||||
|
||||
Returns:
|
||||
List of request log entries, newest first
|
||||
"""
|
||||
try:
|
||||
key = f"trigger:endpoint_requests:{endpoint_id}"
|
||||
|
||||
# Set time bounds
|
||||
if end_time_ms is None:
|
||||
end_time_ms = int(time.time() * 1000)
|
||||
if start_time_ms is None:
|
||||
start_time_ms = end_time_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__
|
||||
|
||||
# Get entries in reverse order (newest first)
|
||||
entries = redis_client.zrevrangebyscore(key, max=end_time_ms, min=start_time_ms, start=0, num=limit)
|
||||
|
||||
# Parse JSON entries
|
||||
requests = []
|
||||
for entry in entries:
|
||||
try:
|
||||
requests.append(json.loads(entry))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse log entry: %s", entry)
|
||||
|
||||
return requests
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to retrieve endpoint requests for %s", endpoint_id, exc_info=e)
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def clear_endpoint_requests(cls, endpoint_id: str) -> bool:
|
||||
"""
|
||||
Clear all logged requests for an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_id: The endpoint identifier
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = f"trigger:endpoint_requests:{endpoint_id}"
|
||||
redis_client.delete(key)
|
||||
logger.info("Cleared request logs for endpoint %s", endpoint_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception("Failed to clear endpoint requests for %s", endpoint_id, exc_info=e)
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,655 @@
|
|||
import pytest
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.utils.http_parser import (
|
||||
deserialize_request,
|
||||
deserialize_response,
|
||||
serialize_request,
|
||||
serialize_response,
|
||||
)
|
||||
|
||||
|
||||
class TestSerializeRequest:
|
||||
def test_serialize_simple_get_request(self):
|
||||
# Create a simple GET request
|
||||
environ = {
|
||||
"REQUEST_METHOD": "GET",
|
||||
"PATH_INFO": "/api/test",
|
||||
"QUERY_STRING": "",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "8000",
|
||||
"wsgi.input": None,
|
||||
"wsgi.url_scheme": "http",
|
||||
}
|
||||
request = Request(environ)
|
||||
|
||||
raw_data = serialize_request(request)
|
||||
|
||||
assert raw_data.startswith(b"GET /api/test HTTP/1.1\r\n")
|
||||
assert b"\r\n\r\n" in raw_data # Empty line between headers and body
|
||||
|
||||
def test_serialize_request_with_query_params(self):
|
||||
# Create a GET request with query parameters
|
||||
environ = {
|
||||
"REQUEST_METHOD": "GET",
|
||||
"PATH_INFO": "/api/search",
|
||||
"QUERY_STRING": "q=test&limit=10",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "8000",
|
||||
"wsgi.input": None,
|
||||
"wsgi.url_scheme": "http",
|
||||
}
|
||||
request = Request(environ)
|
||||
|
||||
raw_data = serialize_request(request)
|
||||
|
||||
assert raw_data.startswith(b"GET /api/search?q=test&limit=10 HTTP/1.1\r\n")
|
||||
|
||||
def test_serialize_post_request_with_body(self):
|
||||
# Create a POST request with body
|
||||
from io import BytesIO
|
||||
|
||||
body = b'{"name": "test", "value": 123}'
|
||||
environ = {
|
||||
"REQUEST_METHOD": "POST",
|
||||
"PATH_INFO": "/api/data",
|
||||
"QUERY_STRING": "",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "8000",
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "http",
|
||||
"CONTENT_LENGTH": str(len(body)),
|
||||
"CONTENT_TYPE": "application/json",
|
||||
"HTTP_CONTENT_TYPE": "application/json",
|
||||
}
|
||||
request = Request(environ)
|
||||
|
||||
raw_data = serialize_request(request)
|
||||
|
||||
assert b"POST /api/data HTTP/1.1\r\n" in raw_data
|
||||
assert b"Content-Type: application/json" in raw_data
|
||||
assert raw_data.endswith(body)
|
||||
|
||||
def test_serialize_request_with_custom_headers(self):
|
||||
# Create a request with custom headers
|
||||
environ = {
|
||||
"REQUEST_METHOD": "GET",
|
||||
"PATH_INFO": "/api/test",
|
||||
"QUERY_STRING": "",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "8000",
|
||||
"wsgi.input": None,
|
||||
"wsgi.url_scheme": "http",
|
||||
"HTTP_AUTHORIZATION": "Bearer token123",
|
||||
"HTTP_X_CUSTOM_HEADER": "custom-value",
|
||||
}
|
||||
request = Request(environ)
|
||||
|
||||
raw_data = serialize_request(request)
|
||||
|
||||
assert b"Authorization: Bearer token123" in raw_data
|
||||
assert b"X-Custom-Header: custom-value" in raw_data
|
||||
|
||||
|
||||
class TestDeserializeRequest:
|
||||
def test_deserialize_simple_get_request(self):
|
||||
raw_data = b"GET /api/test HTTP/1.1\r\nHost: localhost:8000\r\n\r\n"
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/api/test"
|
||||
assert request.headers.get("Host") == "localhost:8000"
|
||||
|
||||
def test_deserialize_request_with_query_params(self):
|
||||
raw_data = b"GET /api/search?q=test&limit=10 HTTP/1.1\r\nHost: example.com\r\n\r\n"
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/api/search"
|
||||
assert request.query_string == b"q=test&limit=10"
|
||||
assert request.args.get("q") == "test"
|
||||
assert request.args.get("limit") == "10"
|
||||
|
||||
def test_deserialize_post_request_with_body(self):
|
||||
body = b'{"name": "test", "value": 123}'
|
||||
raw_data = (
|
||||
b"POST /api/data HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Content-Type: application/json\r\n"
|
||||
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
|
||||
b"\r\n" + body
|
||||
)
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "POST"
|
||||
assert request.path == "/api/data"
|
||||
assert request.content_type == "application/json"
|
||||
assert request.get_data() == body
|
||||
|
||||
def test_deserialize_request_with_custom_headers(self):
|
||||
raw_data = (
|
||||
b"GET /api/protected HTTP/1.1\r\n"
|
||||
b"Host: api.example.com\r\n"
|
||||
b"Authorization: Bearer token123\r\n"
|
||||
b"X-Custom-Header: custom-value\r\n"
|
||||
b"User-Agent: TestClient/1.0\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.headers.get("Authorization") == "Bearer token123"
|
||||
assert request.headers.get("X-Custom-Header") == "custom-value"
|
||||
assert request.headers.get("User-Agent") == "TestClient/1.0"
|
||||
|
||||
def test_deserialize_request_with_multiline_body(self):
|
||||
body = b"line1\r\nline2\r\nline3"
|
||||
raw_data = b"PUT /api/text HTTP/1.1\r\nHost: localhost\r\nContent-Type: text/plain\r\n\r\n" + body
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "PUT"
|
||||
assert request.get_data() == body
|
||||
|
||||
def test_deserialize_invalid_request_line(self):
|
||||
raw_data = b"INVALID\r\n\r\n" # Only one part, should fail
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid request line"):
|
||||
deserialize_request(raw_data)
|
||||
|
||||
def test_roundtrip_request(self):
|
||||
# Test that serialize -> deserialize produces equivalent request
|
||||
from io import BytesIO
|
||||
|
||||
body = b"test body content"
|
||||
environ = {
|
||||
"REQUEST_METHOD": "POST",
|
||||
"PATH_INFO": "/api/echo",
|
||||
"QUERY_STRING": "format=json",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "8080",
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "http",
|
||||
"CONTENT_LENGTH": str(len(body)),
|
||||
"CONTENT_TYPE": "text/plain",
|
||||
"HTTP_CONTENT_TYPE": "text/plain",
|
||||
"HTTP_X_REQUEST_ID": "req-123",
|
||||
}
|
||||
original_request = Request(environ)
|
||||
|
||||
# Serialize and deserialize
|
||||
raw_data = serialize_request(original_request)
|
||||
restored_request = deserialize_request(raw_data)
|
||||
|
||||
# Verify key properties are preserved
|
||||
assert restored_request.method == original_request.method
|
||||
assert restored_request.path == original_request.path
|
||||
assert restored_request.query_string == original_request.query_string
|
||||
assert restored_request.get_data() == body
|
||||
assert restored_request.headers.get("X-Request-Id") == "req-123"
|
||||
|
||||
|
||||
class TestSerializeResponse:
|
||||
def test_serialize_simple_response(self):
|
||||
response = Response("Hello, World!", status=200)
|
||||
|
||||
raw_data = serialize_response(response)
|
||||
|
||||
assert raw_data.startswith(b"HTTP/1.1 200 OK\r\n")
|
||||
assert b"\r\n\r\n" in raw_data
|
||||
assert raw_data.endswith(b"Hello, World!")
|
||||
|
||||
def test_serialize_response_with_headers(self):
|
||||
response = Response(
|
||||
'{"status": "success"}',
|
||||
status=201,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"X-Request-Id": "req-456",
|
||||
},
|
||||
)
|
||||
|
||||
raw_data = serialize_response(response)
|
||||
|
||||
assert b"HTTP/1.1 201 CREATED\r\n" in raw_data
|
||||
assert b"Content-Type: application/json" in raw_data
|
||||
assert b"X-Request-Id: req-456" in raw_data
|
||||
assert raw_data.endswith(b'{"status": "success"}')
|
||||
|
||||
def test_serialize_error_response(self):
|
||||
response = Response(
|
||||
"Not Found",
|
||||
status=404,
|
||||
headers={"Content-Type": "text/plain"},
|
||||
)
|
||||
|
||||
raw_data = serialize_response(response)
|
||||
|
||||
assert b"HTTP/1.1 404 NOT FOUND\r\n" in raw_data
|
||||
assert b"Content-Type: text/plain" in raw_data
|
||||
assert raw_data.endswith(b"Not Found")
|
||||
|
||||
def test_serialize_response_without_body(self):
|
||||
response = Response(status=204) # No Content
|
||||
|
||||
raw_data = serialize_response(response)
|
||||
|
||||
assert b"HTTP/1.1 204 NO CONTENT\r\n" in raw_data
|
||||
assert raw_data.endswith(b"\r\n\r\n") # Should end with empty line
|
||||
|
||||
def test_serialize_response_with_binary_body(self):
|
||||
binary_data = b"\x00\x01\x02\x03\x04\x05"
|
||||
response = Response(
|
||||
binary_data,
|
||||
status=200,
|
||||
headers={"Content-Type": "application/octet-stream"},
|
||||
)
|
||||
|
||||
raw_data = serialize_response(response)
|
||||
|
||||
assert b"HTTP/1.1 200 OK\r\n" in raw_data
|
||||
assert b"Content-Type: application/octet-stream" in raw_data
|
||||
assert raw_data.endswith(binary_data)
|
||||
|
||||
|
||||
class TestDeserializeResponse:
|
||||
def test_deserialize_simple_response(self):
|
||||
raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nHello, World!"
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_data() == b"Hello, World!"
|
||||
assert response.headers.get("Content-Type") == "text/plain"
|
||||
|
||||
def test_deserialize_response_with_json(self):
|
||||
body = b'{"result": "success", "data": [1, 2, 3]}'
|
||||
raw_data = (
|
||||
b"HTTP/1.1 201 Created\r\n"
|
||||
b"Content-Type: application/json\r\n"
|
||||
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
|
||||
b"X-Custom-Header: test-value\r\n"
|
||||
b"\r\n" + body
|
||||
)
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.get_data() == body
|
||||
assert response.headers.get("Content-Type") == "application/json"
|
||||
assert response.headers.get("X-Custom-Header") == "test-value"
|
||||
|
||||
def test_deserialize_error_response(self):
|
||||
raw_data = b"HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\n<html><body>Page not found</body></html>"
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.get_data() == b"<html><body>Page not found</body></html>"
|
||||
|
||||
def test_deserialize_response_without_body(self):
|
||||
raw_data = b"HTTP/1.1 204 No Content\r\n\r\n"
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 204
|
||||
assert response.get_data() == b""
|
||||
|
||||
def test_deserialize_response_with_multiline_body(self):
|
||||
body = b"Line 1\r\nLine 2\r\nLine 3"
|
||||
raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n" + body
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_data() == body
|
||||
|
||||
def test_deserialize_response_minimal_status_line(self):
|
||||
# Test with minimal status line (no status text)
|
||||
raw_data = b"HTTP/1.1 200\r\n\r\nOK"
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_data() == b"OK"
|
||||
|
||||
def test_deserialize_invalid_status_line(self):
|
||||
raw_data = b"INVALID\r\n\r\n"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid status line"):
|
||||
deserialize_response(raw_data)
|
||||
|
||||
def test_roundtrip_response(self):
|
||||
# Test that serialize -> deserialize produces equivalent response
|
||||
original_response = Response(
|
||||
'{"message": "test"}',
|
||||
status=200,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"X-Request-Id": "abc-123",
|
||||
"Cache-Control": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
# Serialize and deserialize
|
||||
raw_data = serialize_response(original_response)
|
||||
restored_response = deserialize_response(raw_data)
|
||||
|
||||
# Verify key properties are preserved
|
||||
assert restored_response.status_code == original_response.status_code
|
||||
assert restored_response.get_data() == original_response.get_data()
|
||||
assert restored_response.headers.get("Content-Type") == "application/json"
|
||||
assert restored_response.headers.get("X-Request-Id") == "abc-123"
|
||||
assert restored_response.headers.get("Cache-Control") == "no-cache"
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_request_with_empty_headers(self):
|
||||
raw_data = b"GET / HTTP/1.1\r\n\r\n"
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/"
|
||||
|
||||
def test_response_with_empty_headers(self):
|
||||
raw_data = b"HTTP/1.1 200 OK\r\n\r\nSuccess"
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_data() == b"Success"
|
||||
|
||||
def test_request_with_special_characters_in_path(self):
|
||||
raw_data = b"GET /api/test%20path?key=%26value HTTP/1.1\r\n\r\n"
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "GET"
|
||||
assert "/api/test%20path" in request.full_path
|
||||
|
||||
def test_response_with_binary_content(self):
|
||||
binary_body = bytes(range(256)) # All possible byte values
|
||||
raw_data = b"HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n\r\n" + binary_body
|
||||
|
||||
response = deserialize_response(raw_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_data() == binary_body
|
||||
|
||||
|
||||
class TestFileUploads:
|
||||
def test_serialize_request_with_text_file_upload(self):
|
||||
# Test multipart/form-data request with text file
|
||||
from io import BytesIO
|
||||
|
||||
boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
|
||||
text_content = "Hello, this is a test file content!\nWith multiple lines."
|
||||
body = (
|
||||
f"------{boundary}\r\n"
|
||||
f'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n'
|
||||
f"Content-Type: text/plain\r\n"
|
||||
f"\r\n"
|
||||
f"{text_content}\r\n"
|
||||
f"------{boundary}\r\n"
|
||||
f'Content-Disposition: form-data; name="description"\r\n'
|
||||
f"\r\n"
|
||||
f"Test file upload\r\n"
|
||||
f"------{boundary}--\r\n"
|
||||
).encode()
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": "POST",
|
||||
"PATH_INFO": "/api/upload",
|
||||
"QUERY_STRING": "",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "8000",
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "http",
|
||||
"CONTENT_LENGTH": str(len(body)),
|
||||
"CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
|
||||
"HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
|
||||
}
|
||||
request = Request(environ)
|
||||
|
||||
raw_data = serialize_request(request)
|
||||
|
||||
assert b"POST /api/upload HTTP/1.1\r\n" in raw_data
|
||||
assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
|
||||
assert b"Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"" in raw_data
|
||||
assert text_content.encode() in raw_data
|
||||
|
||||
def test_deserialize_request_with_text_file_upload(self):
|
||||
# Test deserializing multipart/form-data request with text file
|
||||
boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
|
||||
text_content = "Sample text file content\nLine 2\nLine 3"
|
||||
body = (
|
||||
f"------{boundary}\r\n"
|
||||
f'Content-Disposition: form-data; name="document"; filename="document.txt"\r\n'
|
||||
f"Content-Type: text/plain\r\n"
|
||||
f"\r\n"
|
||||
f"{text_content}\r\n"
|
||||
f"------{boundary}\r\n"
|
||||
f'Content-Disposition: form-data; name="title"\r\n'
|
||||
f"\r\n"
|
||||
f"My Document\r\n"
|
||||
f"------{boundary}--\r\n"
|
||||
).encode()
|
||||
|
||||
raw_data = (
|
||||
b"POST /api/documents HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
|
||||
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
|
||||
b"\r\n" + body
|
||||
)
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "POST"
|
||||
assert request.path == "/api/documents"
|
||||
assert "multipart/form-data" in request.content_type
|
||||
# The body should contain the multipart data
|
||||
request_body = request.get_data()
|
||||
assert b"document.txt" in request_body
|
||||
assert text_content.encode() in request_body
|
||||
|
||||
def test_serialize_request_with_binary_file_upload(self):
|
||||
# Test multipart/form-data request with binary file (e.g., image)
|
||||
from io import BytesIO
|
||||
|
||||
boundary = "----BoundaryString123"
|
||||
# Simulate a small PNG file header
|
||||
binary_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10"
|
||||
|
||||
# Build multipart body
|
||||
body_parts = []
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="image"; filename="test.png"')
|
||||
body_parts.append(b"Content-Type: image/png")
|
||||
body_parts.append(b"")
|
||||
body_parts.append(binary_content)
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="caption"')
|
||||
body_parts.append(b"")
|
||||
body_parts.append(b"Test image")
|
||||
body_parts.append(f"------{boundary}--".encode())
|
||||
|
||||
body = b"\r\n".join(body_parts)
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": "POST",
|
||||
"PATH_INFO": "/api/images",
|
||||
"QUERY_STRING": "",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "8000",
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "http",
|
||||
"CONTENT_LENGTH": str(len(body)),
|
||||
"CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
|
||||
"HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
|
||||
}
|
||||
request = Request(environ)
|
||||
|
||||
raw_data = serialize_request(request)
|
||||
|
||||
assert b"POST /api/images HTTP/1.1\r\n" in raw_data
|
||||
assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
|
||||
assert b"filename=\"test.png\"" in raw_data
|
||||
assert b"Content-Type: image/png" in raw_data
|
||||
assert binary_content in raw_data
|
||||
|
||||
def test_deserialize_request_with_binary_file_upload(self):
|
||||
# Test deserializing multipart/form-data request with binary file
|
||||
boundary = "----BoundaryABC123"
|
||||
# Simulate a small JPEG file header
|
||||
binary_content = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00"
|
||||
|
||||
body_parts = []
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="photo"; filename="photo.jpg"')
|
||||
body_parts.append(b"Content-Type: image/jpeg")
|
||||
body_parts.append(b"")
|
||||
body_parts.append(binary_content)
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="album"')
|
||||
body_parts.append(b"")
|
||||
body_parts.append(b"Vacation 2024")
|
||||
body_parts.append(f"------{boundary}--".encode())
|
||||
|
||||
body = b"\r\n".join(body_parts)
|
||||
|
||||
raw_data = (
|
||||
b"POST /api/photos HTTP/1.1\r\n"
|
||||
b"Host: api.example.com\r\n"
|
||||
b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
|
||||
b"Content-Length: " + str(len(body)).encode() + b"\r\n"
|
||||
b"Accept: application/json\r\n"
|
||||
b"\r\n" + body
|
||||
)
|
||||
|
||||
request = deserialize_request(raw_data)
|
||||
|
||||
assert request.method == "POST"
|
||||
assert request.path == "/api/photos"
|
||||
assert "multipart/form-data" in request.content_type
|
||||
assert request.headers.get("Accept") == "application/json"
|
||||
|
||||
# Verify the binary content is preserved
|
||||
request_body = request.get_data()
|
||||
assert b"photo.jpg" in request_body
|
||||
assert b"image/jpeg" in request_body
|
||||
assert binary_content in request_body
|
||||
assert b"Vacation 2024" in request_body
|
||||
|
||||
def test_serialize_request_with_multiple_files(self):
|
||||
# Test request with multiple file uploads
|
||||
from io import BytesIO
|
||||
|
||||
boundary = "----MultiFilesBoundary"
|
||||
text_file = b"Text file contents"
|
||||
binary_file = b"\x00\x01\x02\x03\x04\x05"
|
||||
|
||||
body_parts = []
|
||||
# First file (text)
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="files"; filename="doc.txt"')
|
||||
body_parts.append(b"Content-Type: text/plain")
|
||||
body_parts.append(b"")
|
||||
body_parts.append(text_file)
|
||||
# Second file (binary)
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="files"; filename="data.bin"')
|
||||
body_parts.append(b"Content-Type: application/octet-stream")
|
||||
body_parts.append(b"")
|
||||
body_parts.append(binary_file)
|
||||
# Additional form field
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="folder"')
|
||||
body_parts.append(b"")
|
||||
body_parts.append(b"uploads/2024")
|
||||
body_parts.append(f"------{boundary}--".encode())
|
||||
|
||||
body = b"\r\n".join(body_parts)
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": "POST",
|
||||
"PATH_INFO": "/api/batch-upload",
|
||||
"QUERY_STRING": "",
|
||||
"SERVER_NAME": "localhost",
|
||||
"SERVER_PORT": "8000",
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "https",
|
||||
"CONTENT_LENGTH": str(len(body)),
|
||||
"CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
|
||||
"HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
|
||||
"HTTP_X_FORWARDED_PROTO": "https",
|
||||
}
|
||||
request = Request(environ)
|
||||
|
||||
raw_data = serialize_request(request)
|
||||
|
||||
assert b"POST /api/batch-upload HTTP/1.1\r\n" in raw_data
|
||||
assert b"doc.txt" in raw_data
|
||||
assert b"data.bin" in raw_data
|
||||
assert text_file in raw_data
|
||||
assert binary_file in raw_data
|
||||
assert b"uploads/2024" in raw_data
|
||||
|
||||
def test_roundtrip_file_upload_request(self):
|
||||
# Test that file upload request survives serialize -> deserialize
|
||||
from io import BytesIO
|
||||
|
||||
boundary = "----RoundTripBoundary"
|
||||
file_content = b"This is my file content with special chars: \xf0\x9f\x98\x80"
|
||||
|
||||
body_parts = []
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="upload"; filename="emoji.txt"')
|
||||
body_parts.append(b"Content-Type: text/plain; charset=utf-8")
|
||||
body_parts.append(b"")
|
||||
body_parts.append(file_content)
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="metadata"')
|
||||
body_parts.append(b"")
|
||||
body_parts.append(b'{"encoding": "utf-8", "size": 42}')
|
||||
body_parts.append(f"------{boundary}--".encode())
|
||||
|
||||
body = b"\r\n".join(body_parts)
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": "PUT",
|
||||
"PATH_INFO": "/api/files/123",
|
||||
"QUERY_STRING": "version=2",
|
||||
"SERVER_NAME": "storage.example.com",
|
||||
"SERVER_PORT": "443",
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "https",
|
||||
"CONTENT_LENGTH": str(len(body)),
|
||||
"CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
|
||||
"HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
|
||||
"HTTP_AUTHORIZATION": "Bearer token123",
|
||||
"HTTP_X_FORWARDED_PROTO": "https",
|
||||
}
|
||||
original_request = Request(environ)
|
||||
|
||||
# Serialize and deserialize
|
||||
raw_data = serialize_request(original_request)
|
||||
restored_request = deserialize_request(raw_data)
|
||||
|
||||
# Verify the request is preserved
|
||||
assert restored_request.method == "PUT"
|
||||
assert restored_request.path == "/api/files/123"
|
||||
assert restored_request.query_string == b"version=2"
|
||||
assert "multipart/form-data" in restored_request.content_type
|
||||
assert boundary in restored_request.content_type
|
||||
|
||||
# Verify file content is preserved
|
||||
restored_body = restored_request.get_data()
|
||||
assert b"emoji.txt" in restored_body
|
||||
assert file_content in restored_body
|
||||
assert b'{"encoding": "utf-8", "size": 42}' in restored_body
|
||||
Loading…
Reference in New Issue