diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py
index 9f00df67a4..d17f22fd95 100644
--- a/api/controllers/console/workspace/trigger_providers.py
+++ b/api/controllers/console/workspace/trigger_providers.py
@@ -155,7 +155,6 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
 
 
 class TriggerSubscriptionBuilderRequestLogsApi(Resource):
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -470,7 +469,7 @@ class TriggerOAuthClientManageApi(Resource):
 
 
 # Trigger Subscription
-api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-providers")
+api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
 api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider//subscriptions/list")
 api.add_resource(
     TriggerSubscriptionDeleteApi,
diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py
index def3c897e0..48ec3be5c8 100644
--- a/api/core/helper/provider_cache.py
+++ b/api/core/helper/provider_cache.py
@@ -68,31 +68,6 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
         return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
 
 
-class TriggerProviderCredentialsCache(ProviderCredentialsCache):
-    """Cache for trigger provider credentials"""
-
-    def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
-        super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
-
-    def _generate_cache_key(self, **kwargs) -> str:
-        tenant_id = kwargs["tenant_id"]
-        provider_id = kwargs["provider_id"]
-        credential_id = kwargs["credential_id"]
-        return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
-
-
-class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
-    """Cache for trigger provider OAuth client"""
-
-    def __init__(self, tenant_id: str, provider_id: str):
-        super().__init__(tenant_id=tenant_id, provider_id=provider_id)
-
-    def _generate_cache_key(self, **kwargs) -> str:
-        tenant_id = kwargs["tenant_id"]
-        provider_id = kwargs["provider_id"]
-        return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
-
-
 class NoOpProviderCredentialCache:
     """No-op provider credential cache"""
 
diff --git a/api/core/plugin/utils/http_parser.py b/api/core/plugin/utils/http_parser.py
index 9ec781360d..47cdcadcb3 100644
--- a/api/core/plugin/utils/http_parser.py
+++ b/api/core/plugin/utils/http_parser.py
@@ -1,180 +1,159 @@
 from io import BytesIO
-from typing import Any
 
 from flask import Request, Response
 from werkzeug.datastructures import Headers
 
 
 def serialize_request(request: Request) -> bytes:
-    """
-    Convert a Request object to raw HTTP data.
-
-    Args:
-        request: The Request object to convert.
-
-    Returns:
-        The raw HTTP data as bytes.
-    """
-    # Start with the request line
     method = request.method
-    path = request.full_path
-    # Remove trailing ? if there's no query string
-    path = path.removesuffix("?")
-    protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
-    raw_data = f"{method} {path} {protocol}\r\n".encode()
+    path = request.full_path.rstrip("?")
+    raw = f"{method} {path} HTTP/1.1\r\n".encode()
 
-    # Add headers
-    for header_name, header_value in request.headers.items():
-        raw_data += f"{header_name}: {header_value}\r\n".encode()
+    for name, value in request.headers.items():
+        raw += f"{name}: {value}\r\n".encode()
 
-    # Add empty line to separate headers from body
-    raw_data += b"\r\n"
+    raw += b"\r\n"
 
-    # Add body if exists
     body = request.get_data(as_text=False)
     if body:
-        raw_data += body
+        raw += body
 
-    return raw_data
+    return raw
 
 
 def deserialize_request(raw_data: bytes) -> Request:
-    """
-    Convert raw HTTP data to a Request object.
+    header_end = raw_data.find(b"\r\n\r\n")
+    if header_end == -1:
+        header_end = raw_data.find(b"\n\n")
+        if header_end == -1:
+            header_data = raw_data
+            body = b""
+        else:
+            header_data = raw_data[:header_end]
+            body = raw_data[header_end + 2 :]
+    else:
+        header_data = raw_data[:header_end]
+        body = raw_data[header_end + 4 :]
 
-    Args:
-        raw_data: The raw HTTP data as bytes.
+    lines = header_data.split(b"\r\n")
+    if len(lines) == 1 and b"\n" in lines[0]:
+        lines = header_data.split(b"\n")
 
-    Returns:
-        A Flask Request object.
-    """
-    lines = raw_data.split(b"\r\n")
+    if not lines or not lines[0]:
+        raise ValueError("Empty HTTP request")
 
-    # Parse request line
-    request_line = lines[0].decode("utf-8")
-    parts = request_line.split(" ", 2)  # Split into max 3 parts
+    request_line = lines[0].decode("utf-8", errors="ignore")
+    parts = request_line.split(" ", 2)
     if len(parts) < 2:
         raise ValueError(f"Invalid request line: {request_line}")
 
     method = parts[0]
-    path = parts[1]
+    full_path = parts[1]
     protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
 
-    # Parse headers
+    if "?" in full_path:
+        path, query_string = full_path.split("?", 1)
+    else:
+        path = full_path
+        query_string = ""
+
     headers = Headers()
-    body_start = 0
-    for i in range(1, len(lines)):
-        line = lines[i]
-        if line == b"":
-            body_start = i + 1
-            break
-        if b":" in line:
-            header_line = line.decode("utf-8")
-            name, value = header_line.split(":", 1)
-            headers[name.strip()] = value.strip()
+    for line in lines[1:]:
+        if not line:
+            continue
+        line_str = line.decode("utf-8", errors="ignore")
+        if ":" not in line_str:
+            continue
+        name, value = line_str.split(":", 1)
+        headers.add(name, value.strip())
 
-    # Extract body
-    body = b""
-    if body_start > 0 and body_start < len(lines):
-        body = b"\r\n".join(lines[body_start:])
+    host = headers.get("Host", "localhost")
+    if ":" in host:
+        server_name, server_port = host.rsplit(":", 1)
+    else:
+        server_name = host
+        server_port = "80"
 
-    # Create environ for Request
     environ = {
         "REQUEST_METHOD": method,
-        "PATH_INFO": path.split("?")[0] if "?" in path else path,
-        "QUERY_STRING": path.split("?")[1] if "?" in path else "",
-        "SERVER_NAME": headers.get("Host", "localhost").split(":")[0],
-        "SERVER_PORT": headers.get("Host", "localhost:80").split(":")[1] if ":" in headers.get("Host", "") else "80",
+        "PATH_INFO": path,
+        "QUERY_STRING": query_string,
+        "SERVER_NAME": server_name,
+        "SERVER_PORT": server_port,
         "SERVER_PROTOCOL": protocol,
         "wsgi.input": BytesIO(body),
-        "wsgi.url_scheme": "https" if headers.get("X-Forwarded-Proto") == "https" else "http",
-        "CONTENT_LENGTH": str(len(body)) if body else "0",
-        "CONTENT_TYPE": headers.get("Content-Type", ""),
+        "wsgi.url_scheme": "http",
     }
 
-    # Add headers to environ
-    for header_name, header_value in headers.items():
-        env_name = f"HTTP_{header_name.upper().replace('-', '_')}"
-        if header_name.upper() not in ["CONTENT-TYPE", "CONTENT-LENGTH"]:
-            environ[env_name] = header_value
+    if "Content-Type" in headers:
+        environ["CONTENT_TYPE"] = headers.get("Content-Type")
+
+    if "Content-Length" in headers:
+        environ["CONTENT_LENGTH"] = headers.get("Content-Length")
+    elif body:
+        environ["CONTENT_LENGTH"] = str(len(body))
+
+    for name, value in headers.items():
+        if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
+            continue
+        env_name = f"HTTP_{name.upper().replace('-', '_')}"
+        environ[env_name] = value
 
     return Request(environ)
 
 
 def serialize_response(response: Response) -> bytes:
-    """
-    Convert a Response object to raw HTTP data.
+    raw = f"HTTP/1.1 {response.status}\r\n".encode()
 
-    Args:
-        response: The Response object to convert.
+    for name, value in response.headers.items():
+        raw += f"{name}: {value}\r\n".encode()
 
-    Returns:
-        The raw HTTP data as bytes.
-    """
-    # Start with the status line
-    protocol = "HTTP/1.1"
-    status_code = response.status_code
-    status_text = response.status.split(" ", 1)[1] if " " in response.status else "OK"
-    raw_data = f"{protocol} {status_code} {status_text}\r\n".encode()
+    raw += b"\r\n"
 
-    # Add headers
-    for header_name, header_value in response.headers.items():
-        raw_data += f"{header_name}: {header_value}\r\n".encode()
-
-    # Add empty line to separate headers from body
-    raw_data += b"\r\n"
-
-    # Add body if exists
     body = response.get_data(as_text=False)
     if body:
-        raw_data += body
+        raw += body
 
-    return raw_data
+    return raw
 
 
 def deserialize_response(raw_data: bytes) -> Response:
-    """
-    Convert raw HTTP data to a Response object.
+    header_end = raw_data.find(b"\r\n\r\n")
+    if header_end == -1:
+        header_end = raw_data.find(b"\n\n")
+        if header_end == -1:
+            header_data = raw_data
+            body = b""
+        else:
+            header_data = raw_data[:header_end]
+            body = raw_data[header_end + 2 :]
+    else:
+        header_data = raw_data[:header_end]
+        body = raw_data[header_end + 4 :]
 
-    Args:
-        raw_data: The raw HTTP data as bytes.
+    lines = header_data.split(b"\r\n")
+    if len(lines) == 1 and b"\n" in lines[0]:
+        lines = header_data.split(b"\n")
 
-    Returns:
-        A Flask Response object.
-    """
-    lines = raw_data.split(b"\r\n")
+    if not lines or not lines[0]:
+        raise ValueError("Empty HTTP response")
 
-    # Parse status line
-    status_line = lines[0].decode("utf-8")
+    status_line = lines[0].decode("utf-8", errors="ignore")
     parts = status_line.split(" ", 2)
     if len(parts) < 2:
         raise ValueError(f"Invalid status line: {status_line}")
 
     status_code = int(parts[1])
 
-    # Parse headers
-    headers: dict[str, Any] = {}
-    body_start = 0
-    for i in range(1, len(lines)):
-        line = lines[i]
-        if line == b"":
-            body_start = i + 1
-            break
-        if b":" in line:
-            header_line = line.decode("utf-8")
-            name, value = header_line.split(":", 1)
-            headers[name.strip()] = value.strip()
+    response = Response(response=body, status=status_code)
 
-    # Extract body
-    body = b""
-    if body_start > 0 and body_start < len(lines):
-        body = b"\r\n".join(lines[body_start:])
-
-    # Create Response object
-    response = Response(
-        response=body,
-        status=status_code,
-        headers=headers,
-    )
+    for line in lines[1:]:
+        if not line:
+            continue
+        line_str = line.decode("utf-8", errors="ignore")
+        if ":" not in line_str:
+            continue
+        name, value = line_str.split(":", 1)
+        response.headers[name] = value.strip()
 
     return response
diff --git a/api/core/trigger/entities/api_entities.py b/api/core/trigger/entities/api_entities.py
index 641eaa4c7a..911f64d00c 100644
--- a/api/core/trigger/entities/api_entities.py
+++ b/api/core/trigger/entities/api_entities.py
@@ -5,13 +5,12 @@ from pydantic import BaseModel, Field
 
 from core.entities.provider_entities import ProviderConfig
 from core.plugin.entities.plugin_daemon import CredentialType
+from core.tools.entities.common_entities import I18nObject
 from core.trigger.entities.entities import (
-    OAuthSchema,
     SubscriptionSchema,
     TriggerDescription,
-    TriggerEntity,
+    TriggerIdentity,
     TriggerParameter,
-    TriggerProviderIdentity,
 )
 
 
@@ -26,23 +25,36 @@ class TriggerProviderSubscriptionApiEntity(BaseModel):
     properties: dict = Field(description="The properties of the subscription")
 
 
-class TriggerProviderApiEntity(BaseModel):
-    identity: TriggerProviderIdentity = Field(description="The identity of the trigger provider")
-    credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
-    oauth_schema: Optional[OAuthSchema] = Field(description="The OAuth schema of the trigger provider")
-    subscription_schema: Optional[SubscriptionSchema] = Field(
-        description="The subscription schema of the trigger provider"
-    )
-    triggers: list[TriggerEntity] = Field(description="The triggers of the trigger provider")
-
-
 class TriggerApiEntity(BaseModel):
     name: str = Field(description="The name of the trigger")
+    identity: TriggerIdentity = Field(description="The identity of the trigger")
     description: TriggerDescription = Field(description="The description of the trigger")
     parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
     output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
 
 
+class TriggerProviderApiEntity(BaseModel):
+    author: str = Field(..., description="The author of the trigger provider")
+    name: str = Field(..., description="The name of the trigger provider")
+    label: I18nObject = Field(..., description="The label of the trigger provider")
+    description: I18nObject = Field(..., description="The description of the trigger provider")
+    icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
+    icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
+    tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
+
+    plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
+    plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
+
+    credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
+    oauth_client_schema: list[ProviderConfig] = Field(
+        default_factory=list, description="The schema of the OAuth client"
+    )
+    subscription_schema: Optional[SubscriptionSchema] = Field(
+        description="The subscription schema of the trigger provider"
+    )
+    triggers: list[TriggerApiEntity] = Field(description="The triggers of the trigger provider")
+
+
 class SubscriptionBuilderApiEntity(BaseModel):
     id: str = Field(description="The id of the subscription builder")
     name: str = Field(description="The name of the subscription builder")
diff --git a/api/core/trigger/entities/entities.py b/api/core/trigger/entities/entities.py
index 5ec70943e3..84c37ccc72 100644
--- a/api/core/trigger/entities/entities.py
+++ b/api/core/trigger/entities/entities.py
@@ -58,6 +58,7 @@ class TriggerProviderIdentity(BaseModel):
     label: I18nObject = Field(..., description="The label of the trigger provider")
     description: I18nObject = Field(..., description="The description of the trigger provider")
     icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
+    icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
     tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
 
 
diff --git a/api/core/trigger/provider.py b/api/core/trigger/provider.py
index 703e9b0d19..9d95d6f1d4 100644
--- a/api/core/trigger/provider.py
+++ b/api/core/trigger/provider.py
@@ -16,7 +16,7 @@ from core.plugin.entities.request import (
     TriggerInvokeResponse,
 )
 from core.plugin.impl.trigger import PluginTriggerManager
-from core.trigger.entities.api_entities import TriggerProviderApiEntity
+from core.trigger.entities.api_entities import TriggerApiEntity, TriggerProviderApiEntity
 from core.trigger.entities.entities import (
     ProviderConfig,
     Subscription,
@@ -69,7 +69,30 @@ class PluginTriggerProviderController:
         """
         Convert to API entity
         """
-        return TriggerProviderApiEntity(**self.entity.model_dump())
+        return TriggerProviderApiEntity(
+            author=self.entity.identity.author,
+            name=self.entity.identity.name,
+            label=self.entity.identity.label,
+            description=self.entity.identity.description,
+            icon=self.entity.identity.icon,
+            icon_dark=self.entity.identity.icon_dark,
+            tags=self.entity.identity.tags,
+            plugin_id=self.plugin_id,
+            plugin_unique_identifier=self.plugin_unique_identifier,
+            credentials_schema=self.entity.credentials_schema,
+            oauth_client_schema=self.entity.oauth_schema.client_schema if self.entity.oauth_schema else [],
+            subscription_schema=self.entity.subscription_schema,
+            triggers=[
+                TriggerApiEntity(
+                    name=trigger.identity.name,
+                    identity=trigger.identity,
+                    description=trigger.description,
+                    parameters=trigger.parameters,
+                    output_schema=trigger.output_schema,
+                )
+                for trigger in self.entity.triggers
+            ],
+        )
 
     @property
     def identity(self) -> TriggerProviderIdentity:
@@ -173,6 +196,18 @@ class PluginTriggerProviderController:
         """
         return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
 
+    def get_properties_schema(self) -> list[BasicProviderConfig]:
+        """
+        Get properties schema for this provider
+
+        :return: List of properties config schemas
+        """
+        return (
+            [x.to_basic_provider_config() for x in self.entity.subscription_schema.properties_schema.copy()]
+            if self.entity.subscription_schema.properties_schema
+            else []
+        )
+
     def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
         """
         Dispatch a trigger through plugin runtime
diff --git a/api/core/trigger/utils/encryption.py b/api/core/trigger/utils/encryption.py
index 0f49343b82..0a081b1bdb 100644
--- a/api/core/trigger/utils/encryption.py
+++ b/api/core/trigger/utils/encryption.py
@@ -2,7 +2,7 @@ from collections.abc import Mapping
 from typing import Union
 
 from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
-from core.helper.provider_cache import TriggerProviderCredentialsCache, TriggerProviderOAuthClientParamsCache
+from core.helper.provider_cache import ProviderCredentialsCache
 from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
 from core.plugin.entities.plugin_daemon import CredentialType
 from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
@@ -10,6 +10,44 @@ from core.trigger.provider import PluginTriggerProviderController
 from models.trigger import TriggerSubscription
 
 
+class TriggerProviderCredentialsCache(ProviderCredentialsCache):
+    """Cache for trigger provider credentials"""
+
+    def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
+        super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
+
+    def _generate_cache_key(self, **kwargs) -> str:
+        tenant_id = kwargs["tenant_id"]
+        provider_id = kwargs["provider_id"]
+        credential_id = kwargs["credential_id"]
+        return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
+
+
+class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
+    """Cache for trigger provider OAuth client"""
+
+    def __init__(self, tenant_id: str, provider_id: str):
+        super().__init__(tenant_id=tenant_id, provider_id=provider_id)
+
+    def _generate_cache_key(self, **kwargs) -> str:
+        tenant_id = kwargs["tenant_id"]
+        provider_id = kwargs["provider_id"]
+        return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
+
+
+class TriggerProviderPropertiesCache(ProviderCredentialsCache):
+    """Cache for trigger provider properties"""
+
+    def __init__(self, tenant_id: str, provider_id: str, subscription_id: str):
+        super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id)
+
+    def _generate_cache_key(self, **kwargs) -> str:
+        tenant_id = kwargs["tenant_id"]
+        provider_id = kwargs["provider_id"]
+        subscription_id = kwargs["subscription_id"]
+        return f"trigger_properties:tenant_id:{tenant_id}:provider_id:{provider_id}:subscription_id:{subscription_id}"
+
+
 def create_trigger_provider_encrypter_for_subscription(
     tenant_id: str,
     controller: PluginTriggerProviderController,
@@ -28,6 +66,24 @@ def create_trigger_provider_encrypter_for_subscription(
     return encrypter, cache
 
 
+def create_trigger_provider_encrypter_for_properties(
+    tenant_id: str,
+    controller: PluginTriggerProviderController,
+    subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
+) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
+    cache = TriggerProviderPropertiesCache(
+        tenant_id=tenant_id,
+        provider_id=str(controller.get_provider_id()),
+        subscription_id=subscription.id,
+    )
+    encrypter, _ = create_provider_encrypter(
+        tenant_id=tenant_id,
+        config=controller.get_properties_schema(),
+        cache=cache,
+    )
+    return encrypter, cache
+
+
 def create_trigger_provider_encrypter(
     tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
 ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py
index 0431956bd5..8bdded8522 100644
--- a/api/services/trigger/trigger_provider_service.py
+++ b/api/services/trigger/trigger_provider_service.py
@@ -20,6 +20,7 @@ from core.trigger.entities.api_entities import (
 )
 from core.trigger.trigger_manager import TriggerManager
 from core.trigger.utils.encryption import (
+    create_trigger_provider_encrypter_for_properties,
     create_trigger_provider_encrypter_for_subscription,
     create_trigger_provider_oauth_encrypter,
 )
@@ -124,12 +125,18 @@ class TriggerProviderService:
         if existing:
             raise ValueError(f"Credential name '{name}' already exists for this provider")
 
-        encrypter, _ = create_provider_encrypter(
+        credential_encrypter, _ = create_provider_encrypter(
             tenant_id=tenant_id,
             config=provider_controller.get_credential_schema_config(credential_type),
             cache=NoOpProviderCredentialCache(),
         )
 
+        properties_encrypter, _ = create_provider_encrypter(
+            tenant_id=tenant_id,
+            config=provider_controller.get_properties_schema(),
+            cache=NoOpProviderCredentialCache(),
+        )
+
         # Create provider record
         db_provider = TriggerSubscription(
             tenant_id=tenant_id,
@@ -138,8 +145,8 @@ class TriggerProviderService:
             endpoint_id=endpoint_id,
             provider_id=str(provider_id),
             parameters=parameters,
-            properties=properties,
-            credentials=encrypter.encrypt(dict(credentials)),
+            properties=properties_encrypter.encrypt(dict(properties)),
+            credentials=credential_encrypter.encrypt(dict(credentials)),
             credential_type=credential_type.value,
             credential_expires_at=credential_expires_at,
             expires_at=expires_at,
@@ -448,4 +455,22 @@ class TriggerProviderService:
         """
         with Session(db.engine, autoflush=False) as session:
             subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
+            if not subscription:
+                return None
+            provider_controller = TriggerManager.get_trigger_provider(
+                subscription.tenant_id, TriggerProviderID(subscription.provider_id)
+            )
+            credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
+                tenant_id=subscription.tenant_id,
+                controller=provider_controller,
+                subscription=subscription,
+            )
+            subscription.credentials = credential_encrypter.decrypt(subscription.credentials)
+
+            properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
+                tenant_id=subscription.tenant_id,
+                controller=provider_controller,
+                subscription=subscription,
+            )
+            subscription.properties = properties_encrypter.decrypt(subscription.properties)
             return subscription
diff --git a/api/services/trigger_service.py b/api/services/trigger_service.py
index 927154a149..06855dd3fc 100644
--- a/api/services/trigger_service.py
+++ b/api/services/trigger_service.py
@@ -7,10 +7,12 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.plugin.entities.plugin import TriggerProviderID
+from core.plugin.utils.http_parser import serialize_request
 from core.trigger.entities.entities import TriggerEntity
 from core.trigger.trigger_manager import TriggerManager
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from extensions.ext_storage import storage
 from models.account import Account, TenantAccountJoin, TenantAccountRole
 from models.enums import WorkflowRunTriggeredFrom
 from models.trigger import TriggerSubscription
@@ -143,13 +145,30 @@ class TriggerService:
         )
 
         if dispatch_response.triggers:
-            triggers = cls.select_triggers(controller, dispatch_response, provider_id, subscription)
-            for trigger in triggers:
-                cls.process_triggered_workflows(
-                    subscription=subscription,
-                    trigger=trigger,
-                    request=request,
-                )
+            # Process triggers asynchronously to avoid blocking
+            from tasks.trigger_processing_tasks import process_triggers_async
+
+            # Serialize and store the request
+            request_id = f"trigger_request_{uuid.uuid4().hex}"
+            serialized_request = serialize_request(request)
+            storage.save(f"triggers/{request_id}", serialized_request)
+
+            # Queue async task with just the request ID
+            process_triggers_async.delay(
+                endpoint_id=endpoint_id,
+                provider_id=subscription.provider_id,
+                subscription_id=subscription.id,
+                triggers=list(dispatch_response.triggers),
+                request_id=request_id,
+            )
+
+            logger.info(
+                "Queued async processing for %d triggers on endpoint %s with request_id %s",
+                len(dispatch_response.triggers),
+                endpoint_id,
+                request_id,
+            )
+
         return dispatch_response.response
 
     @classmethod
@@ -159,7 +178,6 @@ class TriggerService:
             triggers = session.scalars(
                 select(WorkflowPluginTrigger).where(
                     WorkflowPluginTrigger.trigger_id == trigger_id,
-                    WorkflowPluginTrigger.triggered_by == "production",  # Only production triggers for now
                 )
             ).all()
             return list(triggers)
diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py
new file mode 100644
index 0000000000..ca5cad8251
--- /dev/null
+++ b/api/tasks/trigger_processing_tasks.py
@@ -0,0 +1,140 @@
+"""
+Celery tasks for async trigger processing.
+
+These tasks handle trigger workflow execution asynchronously
+to avoid blocking the main request thread.
+"""
+
+import logging
+
+from celery import shared_task
+from sqlalchemy.orm import Session
+
+from core.plugin.entities.plugin import TriggerProviderID
+from core.plugin.utils.http_parser import deserialize_request
+from core.trigger.trigger_manager import TriggerManager
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.trigger import TriggerSubscription
+from services.trigger_service import TriggerService
+
+logger = logging.getLogger(__name__)
+
+# Use workflow queue for trigger processing
+TRIGGER_QUEUE = "triggered_workflow_dispatcher"
+
+
+@shared_task(queue=TRIGGER_QUEUE, bind=True, max_retries=3)
+def process_triggers_async(
+    self,
+    endpoint_id: str,
+    provider_id: str,
+    subscription_id: str,
+    triggers: list[str],
+    request_id: str,
+) -> dict:
+    """
+    Process triggers asynchronously.
+
+    Args:
+        endpoint_id: Endpoint ID
+        provider_id: Provider ID
+        subscription_id: Subscription ID
+        triggers: List of triggers to process
+        request_id: Unique ID of the stored request
+
+    Returns:
+        dict: Execution result with status and processed trigger count
+    """
+    try:
+        logger.info(
+            "Starting async trigger processing for endpoint=%s, triggers=%s, request_id=%s",
+            endpoint_id,
+            triggers,
+            request_id,
+        )
+
+        # Load request from storage
+        try:
+            serialized_request = storage.load_once(f"triggers/{request_id}")
+            request = deserialize_request(serialized_request)
+        except Exception as e:
+            logger.exception("Failed to load request %s", request_id, exc_info=e)
+            return {"status": "failed", "error": f"Failed to load request: {str(e)}"}
+
+        with Session(db.engine) as session:
+            # Get subscription
+            subscription = session.query(TriggerSubscription).filter_by(id=subscription_id).first()
+            if not subscription:
+                logger.error("Subscription not found: %s", subscription_id)
+                return {"status": "failed", "error": "Subscription not found"}
+
+            # Get controller
+            provider_id_obj = TriggerProviderID(provider_id)
+            controller = TriggerManager.get_trigger_provider(subscription.tenant_id, provider_id_obj)
+            if not controller:
+                logger.error("Controller not found for provider: %s", provider_id)
+                return {"status": "failed", "error": "Controller not found"}
+
+            # Process each trigger
+            processed_count = 0
+            for trigger_name in triggers:
+                try:
+                    trigger = controller.get_trigger(trigger_name)
+                    if trigger is None:
+                        logger.error(
+                            "Trigger '%s' not found in provider '%s'",
+                            trigger_name,
+                            provider_id,
+                        )
+                        continue
+
+                    TriggerService.process_triggered_workflows(
+                        subscription=subscription,
+                        trigger=trigger,
+                        request=request,
+                    )
+                    processed_count += 1
+
+                except Exception:
+                    logger.exception(
+                        "Failed to process trigger '%s' for subscription %s",
+                        trigger_name,
+                        subscription_id,
+                    )
+                    # Continue processing other triggers even if one fails
+                    continue
+
+            logger.info(
+                "Completed async trigger processing: processed %d/%d triggers",
+                processed_count,
+                len(triggers),
+            )
+
+            # Note: Stored request is not deleted here. It should be handled by:
+            # 1. Storage system's lifecycle policy (e.g., S3 lifecycle rules for triggers/* prefix)
+            # 2. Or periodic cleanup job if using local/persistent storage
+            # This ensures request data is available for debugging/retry purposes
+
+            return {
+                "status": "completed",
+                "processed_count": processed_count,
+                "total_count": len(triggers),
+            }
+
+    except Exception as e:
+        logger.exception(
+            "Error in async trigger processing for endpoint %s",
+            endpoint_id,
+        )
+        # Retry the task if not exceeded max retries
+        if self.request.retries < self.max_retries:
+            raise self.retry(exc=e, countdown=60 * (self.request.retries + 1))
+
+        # Note: Stored request is not deleted even on failure. See comment above for cleanup strategy.
+
+        return {
+            "status": "failed",
+            "error": str(e),
+            "retries": self.request.retries,
+        }
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
index ea3fa83af4..627b9d73da 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
@@ -15,7 +15,7 @@ from core.workflow.nodes.trigger_webhook.entities import (
     WebhookData,
     WebhookParameter,
 )
-from core.workflow.nodes.webhook import TriggerWebhookNode
+from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode
 from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
 from models.workflow import WorkflowType