mirror of https://github.com/langgenius/dify.git
refactor(trigger): update trigger provider API and clean up unused classes
- Renamed the API endpoint for trigger providers from `/workspaces/current/trigger-providers` to `/workspaces/current/triggers` for consistency. - Removed unused `TriggerProviderCredentialsCache` and `TriggerProviderOAuthClientParamsCache` classes to streamline the codebase. - Enhanced the `TriggerProviderApiEntity` to include additional properties and improved the conversion logic in `PluginTriggerProviderController`. 🤖 Generated with [Claude Code](https://claude.ai/code)
This commit is contained in:
parent
1fffc79c32
commit
e751c0c535
|
|
@ -155,7 +155,6 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
|||
|
||||
|
||||
class TriggerSubscriptionBuilderRequestLogsApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -470,7 +469,7 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
|
||||
|
||||
# Trigger Subscription
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-providers")
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
api.add_resource(
|
||||
TriggerSubscriptionDeleteApi,
|
||||
|
|
|
|||
|
|
@ -68,31 +68,6 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
|||
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class TriggerProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
credential_id = kwargs["credential_id"]
|
||||
return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider OAuth client"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
|
||||
|
||||
|
||||
class NoOpProviderCredentialCache:
|
||||
"""No-op provider credential cache"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,180 +1,159 @@
|
|||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from flask import Request, Response
|
||||
from werkzeug.datastructures import Headers
|
||||
|
||||
|
||||
def serialize_request(request: Request) -> bytes:
|
||||
"""
|
||||
Convert a Request object to raw HTTP data.
|
||||
|
||||
Args:
|
||||
request: The Request object to convert.
|
||||
|
||||
Returns:
|
||||
The raw HTTP data as bytes.
|
||||
"""
|
||||
# Start with the request line
|
||||
method = request.method
|
||||
path = request.full_path
|
||||
# Remove trailing ? if there's no query string
|
||||
path = path.removesuffix("?")
|
||||
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
|
||||
raw_data = f"{method} {path} {protocol}\r\n".encode()
|
||||
path = request.full_path.rstrip("?")
|
||||
raw = f"{method} {path} HTTP/1.1\r\n".encode()
|
||||
|
||||
# Add headers
|
||||
for header_name, header_value in request.headers.items():
|
||||
raw_data += f"{header_name}: {header_value}\r\n".encode()
|
||||
for name, value in request.headers.items():
|
||||
raw += f"{name}: {value}\r\n".encode()
|
||||
|
||||
# Add empty line to separate headers from body
|
||||
raw_data += b"\r\n"
|
||||
raw += b"\r\n"
|
||||
|
||||
# Add body if exists
|
||||
body = request.get_data(as_text=False)
|
||||
if body:
|
||||
raw_data += body
|
||||
raw += body
|
||||
|
||||
return raw_data
|
||||
return raw
|
||||
|
||||
|
||||
def deserialize_request(raw_data: bytes) -> Request:
|
||||
"""
|
||||
Convert raw HTTP data to a Request object.
|
||||
header_end = raw_data.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
header_end = raw_data.find(b"\n\n")
|
||||
if header_end == -1:
|
||||
header_data = raw_data
|
||||
body = b""
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 2 :]
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 4 :]
|
||||
|
||||
Args:
|
||||
raw_data: The raw HTTP data as bytes.
|
||||
lines = header_data.split(b"\r\n")
|
||||
if len(lines) == 1 and b"\n" in lines[0]:
|
||||
lines = header_data.split(b"\n")
|
||||
|
||||
Returns:
|
||||
A Flask Request object.
|
||||
"""
|
||||
lines = raw_data.split(b"\r\n")
|
||||
if not lines or not lines[0]:
|
||||
raise ValueError("Empty HTTP request")
|
||||
|
||||
# Parse request line
|
||||
request_line = lines[0].decode("utf-8")
|
||||
parts = request_line.split(" ", 2) # Split into max 3 parts
|
||||
request_line = lines[0].decode("utf-8", errors="ignore")
|
||||
parts = request_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid request line: {request_line}")
|
||||
|
||||
method = parts[0]
|
||||
path = parts[1]
|
||||
full_path = parts[1]
|
||||
protocol = parts[2] if len(parts) > 2 else "HTTP/1.1"
|
||||
|
||||
# Parse headers
|
||||
if "?" in full_path:
|
||||
path, query_string = full_path.split("?", 1)
|
||||
else:
|
||||
path = full_path
|
||||
query_string = ""
|
||||
|
||||
headers = Headers()
|
||||
body_start = 0
|
||||
for i in range(1, len(lines)):
|
||||
line = lines[i]
|
||||
if line == b"":
|
||||
body_start = i + 1
|
||||
break
|
||||
if b":" in line:
|
||||
header_line = line.decode("utf-8")
|
||||
name, value = header_line.split(":", 1)
|
||||
headers[name.strip()] = value.strip()
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8", errors="ignore")
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
headers.add(name, value.strip())
|
||||
|
||||
# Extract body
|
||||
body = b""
|
||||
if body_start > 0 and body_start < len(lines):
|
||||
body = b"\r\n".join(lines[body_start:])
|
||||
host = headers.get("Host", "localhost")
|
||||
if ":" in host:
|
||||
server_name, server_port = host.rsplit(":", 1)
|
||||
else:
|
||||
server_name = host
|
||||
server_port = "80"
|
||||
|
||||
# Create environ for Request
|
||||
environ = {
|
||||
"REQUEST_METHOD": method,
|
||||
"PATH_INFO": path.split("?")[0] if "?" in path else path,
|
||||
"QUERY_STRING": path.split("?")[1] if "?" in path else "",
|
||||
"SERVER_NAME": headers.get("Host", "localhost").split(":")[0],
|
||||
"SERVER_PORT": headers.get("Host", "localhost:80").split(":")[1] if ":" in headers.get("Host", "") else "80",
|
||||
"PATH_INFO": path,
|
||||
"QUERY_STRING": query_string,
|
||||
"SERVER_NAME": server_name,
|
||||
"SERVER_PORT": server_port,
|
||||
"SERVER_PROTOCOL": protocol,
|
||||
"wsgi.input": BytesIO(body),
|
||||
"wsgi.url_scheme": "https" if headers.get("X-Forwarded-Proto") == "https" else "http",
|
||||
"CONTENT_LENGTH": str(len(body)) if body else "0",
|
||||
"CONTENT_TYPE": headers.get("Content-Type", ""),
|
||||
"wsgi.url_scheme": "http",
|
||||
}
|
||||
|
||||
# Add headers to environ
|
||||
for header_name, header_value in headers.items():
|
||||
env_name = f"HTTP_{header_name.upper().replace('-', '_')}"
|
||||
if header_name.upper() not in ["CONTENT-TYPE", "CONTENT-LENGTH"]:
|
||||
environ[env_name] = header_value
|
||||
if "Content-Type" in headers:
|
||||
environ["CONTENT_TYPE"] = headers.get("Content-Type")
|
||||
|
||||
if "Content-Length" in headers:
|
||||
environ["CONTENT_LENGTH"] = headers.get("Content-Length")
|
||||
elif body:
|
||||
environ["CONTENT_LENGTH"] = str(len(body))
|
||||
|
||||
for name, value in headers.items():
|
||||
if name.upper() in ("CONTENT-TYPE", "CONTENT-LENGTH"):
|
||||
continue
|
||||
env_name = f"HTTP_{name.upper().replace('-', '_')}"
|
||||
environ[env_name] = value
|
||||
|
||||
return Request(environ)
|
||||
|
||||
|
||||
def serialize_response(response: Response) -> bytes:
|
||||
"""
|
||||
Convert a Response object to raw HTTP data.
|
||||
raw = f"HTTP/1.1 {response.status}\r\n".encode()
|
||||
|
||||
Args:
|
||||
response: The Response object to convert.
|
||||
for name, value in response.headers.items():
|
||||
raw += f"{name}: {value}\r\n".encode()
|
||||
|
||||
Returns:
|
||||
The raw HTTP data as bytes.
|
||||
"""
|
||||
# Start with the status line
|
||||
protocol = "HTTP/1.1"
|
||||
status_code = response.status_code
|
||||
status_text = response.status.split(" ", 1)[1] if " " in response.status else "OK"
|
||||
raw_data = f"{protocol} {status_code} {status_text}\r\n".encode()
|
||||
raw += b"\r\n"
|
||||
|
||||
# Add headers
|
||||
for header_name, header_value in response.headers.items():
|
||||
raw_data += f"{header_name}: {header_value}\r\n".encode()
|
||||
|
||||
# Add empty line to separate headers from body
|
||||
raw_data += b"\r\n"
|
||||
|
||||
# Add body if exists
|
||||
body = response.get_data(as_text=False)
|
||||
if body:
|
||||
raw_data += body
|
||||
raw += body
|
||||
|
||||
return raw_data
|
||||
return raw
|
||||
|
||||
|
||||
def deserialize_response(raw_data: bytes) -> Response:
|
||||
"""
|
||||
Convert raw HTTP data to a Response object.
|
||||
header_end = raw_data.find(b"\r\n\r\n")
|
||||
if header_end == -1:
|
||||
header_end = raw_data.find(b"\n\n")
|
||||
if header_end == -1:
|
||||
header_data = raw_data
|
||||
body = b""
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 2 :]
|
||||
else:
|
||||
header_data = raw_data[:header_end]
|
||||
body = raw_data[header_end + 4 :]
|
||||
|
||||
Args:
|
||||
raw_data: The raw HTTP data as bytes.
|
||||
lines = header_data.split(b"\r\n")
|
||||
if len(lines) == 1 and b"\n" in lines[0]:
|
||||
lines = header_data.split(b"\n")
|
||||
|
||||
Returns:
|
||||
A Flask Response object.
|
||||
"""
|
||||
lines = raw_data.split(b"\r\n")
|
||||
if not lines or not lines[0]:
|
||||
raise ValueError("Empty HTTP response")
|
||||
|
||||
# Parse status line
|
||||
status_line = lines[0].decode("utf-8")
|
||||
status_line = lines[0].decode("utf-8", errors="ignore")
|
||||
parts = status_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid status line: {status_line}")
|
||||
|
||||
status_code = int(parts[1])
|
||||
|
||||
# Parse headers
|
||||
headers: dict[str, Any] = {}
|
||||
body_start = 0
|
||||
for i in range(1, len(lines)):
|
||||
line = lines[i]
|
||||
if line == b"":
|
||||
body_start = i + 1
|
||||
break
|
||||
if b":" in line:
|
||||
header_line = line.decode("utf-8")
|
||||
name, value = header_line.split(":", 1)
|
||||
headers[name.strip()] = value.strip()
|
||||
response = Response(response=body, status=status_code)
|
||||
|
||||
# Extract body
|
||||
body = b""
|
||||
if body_start > 0 and body_start < len(lines):
|
||||
body = b"\r\n".join(lines[body_start:])
|
||||
|
||||
# Create Response object
|
||||
response = Response(
|
||||
response=body,
|
||||
status=status_code,
|
||||
headers=headers,
|
||||
)
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
line_str = line.decode("utf-8", errors="ignore")
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
response.headers[name] = value.strip()
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -5,13 +5,12 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.trigger.entities.entities import (
|
||||
OAuthSchema,
|
||||
SubscriptionSchema,
|
||||
TriggerDescription,
|
||||
TriggerEntity,
|
||||
TriggerIdentity,
|
||||
TriggerParameter,
|
||||
TriggerProviderIdentity,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -26,23 +25,36 @@ class TriggerProviderSubscriptionApiEntity(BaseModel):
|
|||
properties: dict = Field(description="The properties of the subscription")
|
||||
|
||||
|
||||
class TriggerProviderApiEntity(BaseModel):
|
||||
identity: TriggerProviderIdentity = Field(description="The identity of the trigger provider")
|
||||
credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
|
||||
oauth_schema: Optional[OAuthSchema] = Field(description="The OAuth schema of the trigger provider")
|
||||
subscription_schema: Optional[SubscriptionSchema] = Field(
|
||||
description="The subscription schema of the trigger provider"
|
||||
)
|
||||
triggers: list[TriggerEntity] = Field(description="The triggers of the trigger provider")
|
||||
|
||||
|
||||
class TriggerApiEntity(BaseModel):
|
||||
name: str = Field(description="The name of the trigger")
|
||||
identity: TriggerIdentity = Field(description="The identity of the trigger")
|
||||
description: TriggerDescription = Field(description="The description of the trigger")
|
||||
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
|
||||
|
||||
|
||||
class TriggerProviderApiEntity(BaseModel):
|
||||
author: str = Field(..., description="The author of the trigger provider")
|
||||
name: str = Field(..., description="The name of the trigger provider")
|
||||
label: I18nObject = Field(..., description="The label of the trigger provider")
|
||||
description: I18nObject = Field(..., description="The description of the trigger provider")
|
||||
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
|
||||
icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
|
||||
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||
|
||||
credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
|
||||
oauth_client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth client"
|
||||
)
|
||||
subscription_schema: Optional[SubscriptionSchema] = Field(
|
||||
description="The subscription schema of the trigger provider"
|
||||
)
|
||||
triggers: list[TriggerApiEntity] = Field(description="The triggers of the trigger provider")
|
||||
|
||||
|
||||
class SubscriptionBuilderApiEntity(BaseModel):
|
||||
id: str = Field(description="The id of the subscription builder")
|
||||
name: str = Field(description="The name of the subscription builder")
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ class TriggerProviderIdentity(BaseModel):
|
|||
label: I18nObject = Field(..., description="The label of the trigger provider")
|
||||
description: I18nObject = Field(..., description="The description of the trigger provider")
|
||||
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
|
||||
icon_dark: Optional[str] = Field(default=None, description="The dark icon of the trigger provider")
|
||||
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from core.plugin.entities.request import (
|
|||
TriggerInvokeResponse,
|
||||
)
|
||||
from core.plugin.impl.trigger import PluginTriggerManager
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity
|
||||
from core.trigger.entities.api_entities import TriggerApiEntity, TriggerProviderApiEntity
|
||||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
Subscription,
|
||||
|
|
@ -69,7 +69,30 @@ class PluginTriggerProviderController:
|
|||
"""
|
||||
Convert to API entity
|
||||
"""
|
||||
return TriggerProviderApiEntity(**self.entity.model_dump())
|
||||
return TriggerProviderApiEntity(
|
||||
author=self.entity.identity.author,
|
||||
name=self.entity.identity.name,
|
||||
label=self.entity.identity.label,
|
||||
description=self.entity.identity.description,
|
||||
icon=self.entity.identity.icon,
|
||||
icon_dark=self.entity.identity.icon_dark,
|
||||
tags=self.entity.identity.tags,
|
||||
plugin_id=self.plugin_id,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
credentials_schema=self.entity.credentials_schema,
|
||||
oauth_client_schema=self.entity.oauth_schema.client_schema if self.entity.oauth_schema else [],
|
||||
subscription_schema=self.entity.subscription_schema,
|
||||
triggers=[
|
||||
TriggerApiEntity(
|
||||
name=trigger.identity.name,
|
||||
identity=trigger.identity,
|
||||
description=trigger.description,
|
||||
parameters=trigger.parameters,
|
||||
output_schema=trigger.output_schema,
|
||||
)
|
||||
for trigger in self.entity.triggers
|
||||
],
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self) -> TriggerProviderIdentity:
|
||||
|
|
@ -173,6 +196,18 @@ class PluginTriggerProviderController:
|
|||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
|
||||
def get_properties_schema(self) -> list[BasicProviderConfig]:
|
||||
"""
|
||||
Get properties schema for this provider
|
||||
|
||||
:return: List of properties config schemas
|
||||
"""
|
||||
return (
|
||||
[x.to_basic_provider_config() for x in self.entity.subscription_schema.properties_schema.copy()]
|
||||
if self.entity.subscription_schema.properties_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
|
||||
"""
|
||||
Dispatch a trigger through plugin runtime
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from collections.abc import Mapping
|
|||
from typing import Union
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
|
||||
from core.helper.provider_cache import TriggerProviderCredentialsCache, TriggerProviderOAuthClientParamsCache
|
||||
from core.helper.provider_cache import ProviderCredentialsCache
|
||||
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
|
||||
|
|
@ -10,6 +10,44 @@ from core.trigger.provider import PluginTriggerProviderController
|
|||
from models.trigger import TriggerSubscription
|
||||
|
||||
|
||||
class TriggerProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
credential_id = kwargs["credential_id"]
|
||||
return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class TriggerProviderOAuthClientParamsCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider OAuth client"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
|
||||
|
||||
|
||||
class TriggerProviderPropertiesCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider properties"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider_id: str, subscription_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id, subscription_id=subscription_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
subscription_id = kwargs["subscription_id"]
|
||||
return f"trigger_properties:tenant_id:{tenant_id}:provider_id:{provider_id}:subscription_id:{subscription_id}"
|
||||
|
||||
|
||||
def create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id: str,
|
||||
controller: PluginTriggerProviderController,
|
||||
|
|
@ -28,6 +66,24 @@ def create_trigger_provider_encrypter_for_subscription(
|
|||
return encrypter, cache
|
||||
|
||||
|
||||
def create_trigger_provider_encrypter_for_properties(
|
||||
tenant_id: str,
|
||||
controller: PluginTriggerProviderController,
|
||||
subscription: Union[TriggerSubscription, TriggerProviderSubscriptionApiEntity],
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
cache = TriggerProviderPropertiesCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
subscription_id=subscription.id,
|
||||
)
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_properties_schema(),
|
||||
cache=cache,
|
||||
)
|
||||
return encrypter, cache
|
||||
|
||||
|
||||
def create_trigger_provider_encrypter(
|
||||
tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from core.trigger.entities.api_entities import (
|
|||
)
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import (
|
||||
create_trigger_provider_encrypter_for_properties,
|
||||
create_trigger_provider_encrypter_for_subscription,
|
||||
create_trigger_provider_oauth_encrypter,
|
||||
)
|
||||
|
|
@ -124,12 +125,18 @@ class TriggerProviderService:
|
|||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
credential_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
properties_encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_properties_schema(),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# Create provider record
|
||||
db_provider = TriggerSubscription(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -138,8 +145,8 @@ class TriggerProviderService:
|
|||
endpoint_id=endpoint_id,
|
||||
provider_id=str(provider_id),
|
||||
parameters=parameters,
|
||||
properties=properties,
|
||||
credentials=encrypter.encrypt(dict(credentials)),
|
||||
properties=properties_encrypter.encrypt(dict(properties)),
|
||||
credentials=credential_encrypter.encrypt(dict(credentials)),
|
||||
credential_type=credential_type.value,
|
||||
credential_expires_at=credential_expires_at,
|
||||
expires_at=expires_at,
|
||||
|
|
@ -448,4 +455,22 @@ class TriggerProviderService:
|
|||
"""
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
|
||||
if not subscription:
|
||||
return None
|
||||
provider_controller = TriggerManager.get_trigger_provider(
|
||||
subscription.tenant_id, TriggerProviderID(subscription.provider_id)
|
||||
)
|
||||
credential_encrypter, _ = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=subscription.tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
subscription.credentials = credential_encrypter.decrypt(subscription.credentials)
|
||||
|
||||
properties_encrypter, _ = create_trigger_provider_encrypter_for_properties(
|
||||
tenant_id=subscription.tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=subscription,
|
||||
)
|
||||
subscription.properties = properties_encrypter.decrypt(subscription.properties)
|
||||
return subscription
|
||||
|
|
|
|||
|
|
@ -7,10 +7,12 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.utils.http_parser import serialize_request
|
||||
from core.trigger.entities.entities import TriggerEntity
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.trigger import TriggerSubscription
|
||||
|
|
@ -143,13 +145,30 @@ class TriggerService:
|
|||
)
|
||||
|
||||
if dispatch_response.triggers:
|
||||
triggers = cls.select_triggers(controller, dispatch_response, provider_id, subscription)
|
||||
for trigger in triggers:
|
||||
cls.process_triggered_workflows(
|
||||
subscription=subscription,
|
||||
trigger=trigger,
|
||||
request=request,
|
||||
)
|
||||
# Process triggers asynchronously to avoid blocking
|
||||
from tasks.trigger_processing_tasks import process_triggers_async
|
||||
|
||||
# Serialize and store the request
|
||||
request_id = f"trigger_request_{uuid.uuid4().hex}"
|
||||
serialized_request = serialize_request(request)
|
||||
storage.save(f"triggers/{request_id}", serialized_request)
|
||||
|
||||
# Queue async task with just the request ID
|
||||
process_triggers_async.delay(
|
||||
endpoint_id=endpoint_id,
|
||||
provider_id=subscription.provider_id,
|
||||
subscription_id=subscription.id,
|
||||
triggers=list(dispatch_response.triggers),
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Queued async processing for %d triggers on endpoint %s with request_id %s",
|
||||
len(dispatch_response.triggers),
|
||||
endpoint_id,
|
||||
request_id,
|
||||
)
|
||||
|
||||
return dispatch_response.response
|
||||
|
||||
@classmethod
|
||||
|
|
@ -159,7 +178,6 @@ class TriggerService:
|
|||
triggers = session.scalars(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.trigger_id == trigger_id,
|
||||
WorkflowPluginTrigger.triggered_by == "production", # Only production triggers for now
|
||||
)
|
||||
).all()
|
||||
return list(triggers)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,140 @@
|
|||
"""
|
||||
Celery tasks for async trigger processing.
|
||||
|
||||
These tasks handle trigger workflow execution asynchronously
|
||||
to avoid blocking the main request thread.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.utils.http_parser import deserialize_request
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger_service import TriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Use workflow queue for trigger processing
|
||||
TRIGGER_QUEUE = "triggered_workflow_dispatcher"
|
||||
|
||||
|
||||
@shared_task(queue=TRIGGER_QUEUE, bind=True, max_retries=3)
|
||||
def process_triggers_async(
|
||||
self,
|
||||
endpoint_id: str,
|
||||
provider_id: str,
|
||||
subscription_id: str,
|
||||
triggers: list[str],
|
||||
request_id: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Process triggers asynchronously.
|
||||
|
||||
Args:
|
||||
endpoint_id: Endpoint ID
|
||||
provider_id: Provider ID
|
||||
subscription_id: Subscription ID
|
||||
triggers: List of triggers to process
|
||||
request_id: Unique ID of the stored request
|
||||
|
||||
Returns:
|
||||
dict: Execution result with status and processed trigger count
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Starting async trigger processing for endpoint=%s, triggers=%s, request_id=%s",
|
||||
endpoint_id,
|
||||
triggers,
|
||||
request_id,
|
||||
)
|
||||
|
||||
# Load request from storage
|
||||
try:
|
||||
serialized_request = storage.load_once(f"triggers/{request_id}")
|
||||
request = deserialize_request(serialized_request)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to load request %s", request_id, exc_info=e)
|
||||
return {"status": "failed", "error": f"Failed to load request: {str(e)}"}
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get subscription
|
||||
subscription = session.query(TriggerSubscription).filter_by(id=subscription_id).first()
|
||||
if not subscription:
|
||||
logger.error("Subscription not found: %s", subscription_id)
|
||||
return {"status": "failed", "error": "Subscription not found"}
|
||||
|
||||
# Get controller
|
||||
provider_id_obj = TriggerProviderID(provider_id)
|
||||
controller = TriggerManager.get_trigger_provider(subscription.tenant_id, provider_id_obj)
|
||||
if not controller:
|
||||
logger.error("Controller not found for provider: %s", provider_id)
|
||||
return {"status": "failed", "error": "Controller not found"}
|
||||
|
||||
# Process each trigger
|
||||
processed_count = 0
|
||||
for trigger_name in triggers:
|
||||
try:
|
||||
trigger = controller.get_trigger(trigger_name)
|
||||
if trigger is None:
|
||||
logger.error(
|
||||
"Trigger '%s' not found in provider '%s'",
|
||||
trigger_name,
|
||||
provider_id,
|
||||
)
|
||||
continue
|
||||
|
||||
TriggerService.process_triggered_workflows(
|
||||
subscription=subscription,
|
||||
trigger=trigger,
|
||||
request=request,
|
||||
)
|
||||
processed_count += 1
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to process trigger '%s' for subscription %s",
|
||||
trigger_name,
|
||||
subscription_id,
|
||||
)
|
||||
# Continue processing other triggers even if one fails
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Completed async trigger processing: processed %d/%d triggers",
|
||||
processed_count,
|
||||
len(triggers),
|
||||
)
|
||||
|
||||
# Note: Stored request is not deleted here. It should be handled by:
|
||||
# 1. Storage system's lifecycle policy (e.g., S3 lifecycle rules for triggers/* prefix)
|
||||
# 2. Or periodic cleanup job if using local/persistent storage
|
||||
# This ensures request data is available for debugging/retry purposes
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"processed_count": processed_count,
|
||||
"total_count": len(triggers),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Error in async trigger processing for endpoint %s",
|
||||
endpoint_id,
|
||||
)
|
||||
# Retry the task if not exceeded max retries
|
||||
if self.request.retries < self.max_retries:
|
||||
raise self.retry(exc=e, countdown=60 * (self.request.retries + 1))
|
||||
|
||||
# Note: Stored request is not deleted even on failure. See comment above for cleanup strategy.
|
||||
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"retries": self.request.retries,
|
||||
}
|
||||
|
|
@ -15,7 +15,7 @@ from core.workflow.nodes.trigger_webhook.entities import (
|
|||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from core.workflow.nodes.webhook import TriggerWebhookNode
|
||||
from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
|
|
|||
Loading…
Reference in New Issue