mirror of https://github.com/langgenius/dify.git
feat(trigger): introduce subscription builder and enhance trigger management
- Refactor trigger provider classes to improve naming consistency, including renaming classes for subscription management - Implement new TriggerSubscriptionBuilderService for creating and verifying subscription builders - Update API endpoints to support subscription builder creation and verification - Enhance data models to include new attributes for subscription builders - Remove the deprecated TriggerSubscriptionValidationService to streamline the codebase Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
694197a701
commit
afd8989150
|
|
@ -15,6 +15,7 @@ from libs.login import current_user, login_required
|
|||
from models.account import Account
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,7 +32,7 @@ class TriggerProviderListApi(Resource):
|
|||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
|
||||
|
||||
class TriggerSubscriptionListApi(Resource):
|
||||
class TriggerProviderSubscriptionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -54,7 +55,7 @@ class TriggerSubscriptionListApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionsAddApi(Resource):
|
||||
class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -67,31 +68,23 @@ class TriggerSubscriptionsAddApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("expires_at", type=int, required=False, nullable=True, location="json", default=-1)
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Parse credential type
|
||||
try:
|
||||
credential_type = CredentialType(args["credential_type"])
|
||||
except ValueError:
|
||||
raise BadRequest(f"Invalid credential_type. Must be one of: {[t.value for t in CredentialType]}")
|
||||
|
||||
result = TriggerProviderService.add_trigger_provider(
|
||||
credentials = args.get("credentials", {})
|
||||
credential_type = CredentialType.API_KEY if credentials else CredentialType.UNAUTHORIZED
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
credentials=args["credentials"],
|
||||
name=args.get("name"),
|
||||
expires_at=args.get("expires_at", -1),
|
||||
credential_expires_at=-1,
|
||||
expires_at=-1,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
return jsonable_encoder({"subscription_builder": subscription_builder})
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
|
|
@ -99,6 +92,58 @@ class TriggerSubscriptionsAddApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Verify a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.verify_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
return 200
|
||||
except Exception as e:
|
||||
logger.exception("Error verifying provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
TriggerSubscriptionBuilderService.build_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
)
|
||||
return 200
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error building provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionsDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -239,17 +284,17 @@ class TriggerOAuthCallbackApi(Resource):
|
|||
raise Exception("Failed to get OAuth credentials")
|
||||
|
||||
# Save OAuth credentials to database
|
||||
TriggerProviderService.add_trigger_provider(
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
credentials=credentials,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
credentials=dict(credentials),
|
||||
credential_expires_at=expires_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Redirect to OAuth callback page
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback?subscription_id={subscription_builder.id}")
|
||||
|
||||
|
||||
class TriggerOAuthRefreshTokenApi(Resource):
|
||||
|
|
@ -380,11 +425,18 @@ class TriggerOAuthClientManageApi(Resource):
|
|||
|
||||
# Trigger provider endpoints
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-providers")
|
||||
api.add_resource(TriggerProviderSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/list")
|
||||
api.add_resource(
|
||||
TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/list"
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/subscriptions/<path:provider>/create-builder",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionsAddApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/add"
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/subscriptions/<path:provider>/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
"/workspaces/current/trigger-provider/subscriptions/<path:provider>/build/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionsDeleteApi,
|
||||
|
|
@ -393,13 +445,11 @@ api.add_resource(
|
|||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/subscriptions/<path:provider>/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(
|
||||
TriggerOAuthRefreshTokenApi,
|
||||
"/workspaces/current/trigger-provider/subscriptions/<path:subscription_id>/oauth/refresh",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
|
||||
)
|
||||
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask import jsonify, request
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.trigger.trigger_subscription_validation_service import TriggerSubscriptionValidationService
|
||||
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
|
||||
from services.trigger_service import TriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -29,7 +29,7 @@ def trigger_endpoint(endpoint_id: str):
|
|||
raise NotFound("Invalid endpoint ID")
|
||||
handling_chain = [
|
||||
TriggerService.process_endpoint,
|
||||
TriggerSubscriptionValidationService.process_validating_endpoint,
|
||||
TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
|
||||
]
|
||||
try:
|
||||
for handler in handling_chain:
|
||||
|
|
|
|||
|
|
@ -210,12 +210,15 @@ class PluginTriggerProviderEntity(BaseModel):
|
|||
class CredentialType(enum.StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
UNAUTHORIZED = "unauthorized"
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
elif self == CredentialType.UNAUTHORIZED:
|
||||
return "UNAUTHORIZED"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
|
|
@ -236,5 +239,7 @@ class CredentialType(enum.StrEnum):
|
|||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
return cls.OAUTH2
|
||||
elif type_name == "unauthorized":
|
||||
return cls.UNAUTHORIZED
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
|
|
|||
|
|
@ -248,14 +248,17 @@ class PluginTriggerDispatchResponse(BaseModel):
|
|||
triggers: list[str]
|
||||
raw_http_response: str
|
||||
|
||||
|
||||
class TriggerSubscriptionResponse(BaseModel):
|
||||
subscription: dict[str, Any]
|
||||
|
||||
|
||||
class TriggerValidateProviderCredentialsResponse(BaseModel):
|
||||
valid: bool
|
||||
message: str
|
||||
error: str
|
||||
|
||||
|
||||
class TriggerDispatchResponse:
|
||||
triggers: list[str]
|
||||
response: Response
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import binascii
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Request
|
||||
|
|
@ -84,10 +85,10 @@ class PluginTriggerManager(BasePluginClient):
|
|||
user_id: str,
|
||||
provider: str,
|
||||
trigger: str,
|
||||
credentials: dict[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
parameters: dict[str, Any],
|
||||
parameters: Mapping[str, Any],
|
||||
) -> TriggerInvokeResponse:
|
||||
"""
|
||||
Invoke a trigger with the given parameters.
|
||||
|
|
@ -121,7 +122,7 @@ class PluginTriggerManager(BasePluginClient):
|
|||
raise ValueError("No response received from plugin daemon for invoke trigger")
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
|
||||
) -> TriggerValidateProviderCredentialsResponse:
|
||||
"""
|
||||
Validate the credentials of the trigger provider.
|
||||
|
|
@ -155,7 +156,7 @@ class PluginTriggerManager(BasePluginClient):
|
|||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
subscription: dict[str, Any],
|
||||
subscription: Mapping[str, Any],
|
||||
request: Request,
|
||||
) -> TriggerDispatchResponse:
|
||||
"""
|
||||
|
|
@ -194,9 +195,9 @@ class PluginTriggerManager(BasePluginClient):
|
|||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
credentials: dict[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
endpoint: str,
|
||||
parameters: dict[str, Any],
|
||||
parameters: Mapping[str, Any],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Subscribe to a trigger.
|
||||
|
|
@ -233,7 +234,7 @@ class PluginTriggerManager(BasePluginClient):
|
|||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: dict[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Unsubscribe from a trigger.
|
||||
|
|
@ -269,7 +270,7 @@ class PluginTriggerManager(BasePluginClient):
|
|||
user_id: str,
|
||||
provider: str,
|
||||
subscription: Subscription,
|
||||
credentials: dict[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
) -> TriggerSubscriptionResponse:
|
||||
"""
|
||||
Refresh a trigger subscription.
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from core.entities.provider_entities import ProviderConfig
|
|||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import (
|
||||
OAuthSchema,
|
||||
Subscription,
|
||||
SubscriptionSchema,
|
||||
TriggerDescription,
|
||||
TriggerEntity,
|
||||
|
|
@ -40,27 +39,4 @@ class TriggerApiEntity(BaseModel):
|
|||
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
|
||||
|
||||
class SubscriptionValidation(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
provider_id: str
|
||||
endpoint: str
|
||||
parameters: dict
|
||||
properties: dict
|
||||
credentials: dict
|
||||
credential_type: str
|
||||
credential_expires_at: int
|
||||
expires_at: int
|
||||
|
||||
def to_subscription(self) -> Subscription:
|
||||
return Subscription(
|
||||
expires_at=self.expires_at,
|
||||
endpoint=self.endpoint,
|
||||
parameters=self.parameters,
|
||||
properties=self.properties,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"]
|
||||
|
|
|
|||
|
|
@ -115,6 +115,18 @@ class SubscriptionSchema(BaseModel):
|
|||
description="The configuration schema stored in the subscription entity",
|
||||
)
|
||||
|
||||
def get_default_parameters(self) -> Mapping[str, Any]:
|
||||
"""Get the default parameters from the parameters schema"""
|
||||
if not self.parameters_schema:
|
||||
return {}
|
||||
return {param.name: param.default for param in self.parameters_schema if param.default}
|
||||
|
||||
def get_default_properties(self) -> Mapping[str, Any]:
|
||||
"""Get the default properties from the properties schema"""
|
||||
if not self.properties_schema:
|
||||
return {}
|
||||
return {prop.name: prop.default for prop in self.properties_schema if prop.default}
|
||||
|
||||
|
||||
class TriggerProviderEntity(BaseModel):
|
||||
"""
|
||||
|
|
@ -148,13 +160,7 @@ class Subscription(BaseModel):
|
|||
)
|
||||
|
||||
endpoint: str = Field(..., description="The webhook endpoint URL allocated by Dify for receiving events")
|
||||
|
||||
parameters: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="""The parameters of the subscription, this is the creation parameters.
|
||||
Only available when creating a new subscription by credentials(auto subscription), not manual subscription""",
|
||||
)
|
||||
properties: dict[str, Any] = Field(
|
||||
properties: Mapping[str, Any] = Field(
|
||||
..., description="Subscription data containing all properties and provider-specific information"
|
||||
)
|
||||
|
||||
|
|
@ -177,10 +183,43 @@ class Unsubscription(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class RequestLog(BaseModel):
|
||||
id: str
|
||||
endpoint: str
|
||||
request: dict
|
||||
response: dict
|
||||
created_at: str
|
||||
|
||||
|
||||
class SubscriptionBuilder(BaseModel):
|
||||
id: str
|
||||
name: str | None = None
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
provider_id: str
|
||||
endpoint_id: str
|
||||
parameters: Mapping[str, Any]
|
||||
properties: Mapping[str, Any]
|
||||
credentials: Mapping[str, str]
|
||||
credential_type: str | None = None
|
||||
credential_expires_at: int | None = None
|
||||
expires_at: int
|
||||
|
||||
def to_subscription(self) -> Subscription:
|
||||
return Subscription(
|
||||
expires_at=self.expires_at,
|
||||
endpoint=self.endpoint_id,
|
||||
parameters=self.parameters,
|
||||
properties=self.properties,
|
||||
)
|
||||
|
||||
|
||||
# Export all entities
|
||||
__all__ = [
|
||||
"OAuthSchema",
|
||||
"RequestLog",
|
||||
"Subscription",
|
||||
"SubscriptionBuilder",
|
||||
"TriggerDescription",
|
||||
"TriggerEntity",
|
||||
"TriggerIdentity",
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ Trigger Provider Controller for managing trigger providers
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Request
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ from core.trigger.entities.api_entities import TriggerProviderApiEntity
|
|||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
Subscription,
|
||||
SubscriptionSchema,
|
||||
TriggerEntity,
|
||||
TriggerProviderEntity,
|
||||
TriggerProviderIdentity,
|
||||
|
|
@ -91,18 +93,17 @@ class PluginTriggerProviderController:
|
|||
return trigger
|
||||
return None
|
||||
|
||||
def get_subscription_schema(self) -> list[ProviderConfig]:
|
||||
def get_subscription_schema(self) -> SubscriptionSchema:
|
||||
"""
|
||||
Get subscription schema for this provider
|
||||
|
||||
:return: List of subscription config schemas
|
||||
"""
|
||||
# Return the parameters schema from the subscription schema
|
||||
if self.entity.subscription_schema and self.entity.subscription_schema.parameters_schema:
|
||||
return self.entity.subscription_schema.parameters_schema
|
||||
return []
|
||||
return self.entity.subscription_schema
|
||||
|
||||
def validate_credentials(self, credentials: dict) -> TriggerValidateProviderCredentialsResponse:
|
||||
def validate_credentials(
|
||||
self, user_id: str, credentials: Mapping[str, str]
|
||||
) -> TriggerValidateProviderCredentialsResponse:
|
||||
"""
|
||||
Validate credentials against schema
|
||||
|
||||
|
|
@ -123,7 +124,7 @@ class PluginTriggerProviderController:
|
|||
provider_id = self.get_provider_id()
|
||||
return manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id="system", # System validation
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
credentials=credentials,
|
||||
)
|
||||
|
|
@ -153,6 +154,7 @@ class PluginTriggerProviderController:
|
|||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
|
||||
"""
|
||||
|
|
@ -168,7 +170,7 @@ class PluginTriggerProviderController:
|
|||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
|
||||
def dispatch(self,user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
|
||||
def dispatch(self, user_id: str, request: Request, subscription: Subscription) -> TriggerDispatchResponse:
|
||||
"""
|
||||
Dispatch a trigger through plugin runtime
|
||||
|
||||
|
|
@ -193,8 +195,8 @@ class PluginTriggerProviderController:
|
|||
self,
|
||||
user_id: str,
|
||||
trigger_name: str,
|
||||
parameters: dict,
|
||||
credentials: dict,
|
||||
parameters: Mapping[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
) -> TriggerInvokeResponse:
|
||||
|
|
@ -223,7 +225,9 @@ class PluginTriggerProviderController:
|
|||
parameters=parameters,
|
||||
)
|
||||
|
||||
def subscribe_trigger(self, user_id: str, endpoint: str, parameters: dict, credentials: dict) -> Subscription:
|
||||
def subscribe_trigger(
|
||||
self, user_id: str, endpoint: str, parameters: Mapping[str, Any], credentials: Mapping[str, str]
|
||||
) -> Subscription:
|
||||
"""
|
||||
Subscribe to a trigger through plugin runtime
|
||||
|
||||
|
|
@ -247,7 +251,9 @@ class PluginTriggerProviderController:
|
|||
|
||||
return Subscription.model_validate(response.subscription)
|
||||
|
||||
def unsubscribe_trigger(self, user_id: str, subscription: Subscription, credentials: dict) -> Unsubscription:
|
||||
def unsubscribe_trigger(
|
||||
self, user_id: str, subscription: Subscription, credentials: Mapping[str, str]
|
||||
) -> Unsubscription:
|
||||
"""
|
||||
Unsubscribe from a trigger through plugin runtime
|
||||
|
||||
|
|
@ -269,7 +275,7 @@ class PluginTriggerProviderController:
|
|||
|
||||
return Unsubscription.model_validate(response.subscription)
|
||||
|
||||
def refresh_trigger(self, subscription: Subscription, credentials: dict) -> Subscription:
|
||||
def refresh_trigger(self, subscription: Subscription, credentials: Mapping[str, str]) -> Subscription:
|
||||
"""
|
||||
Refresh a trigger subscription through plugin runtime
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ Trigger Manager for loading and managing trigger providers and triggers
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Request
|
||||
|
||||
|
|
@ -12,8 +13,8 @@ from core.plugin.entities.plugin_daemon import CredentialType
|
|||
from core.plugin.entities.request import TriggerInvokeResponse
|
||||
from core.plugin.impl.trigger import PluginTriggerManager
|
||||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
Subscription,
|
||||
SubscriptionSchema,
|
||||
TriggerEntity,
|
||||
Unsubscription,
|
||||
)
|
||||
|
|
@ -116,19 +117,20 @@ class TriggerManager:
|
|||
|
||||
@classmethod
|
||||
def validate_trigger_credentials(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID, credentials: dict
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID, user_id: str, credentials: Mapping[str, str]
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate trigger provider credentials
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider ID
|
||||
:param user_id: User ID
|
||||
:param credentials: Credentials to validate
|
||||
:return: Tuple of (is_valid, error_message)
|
||||
"""
|
||||
try:
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
validation_result = provider.validate_credentials(credentials)
|
||||
validation_result = provider.validate_credentials(user_id, credentials)
|
||||
return validation_result.valid, validation_result.message if not validation_result.valid else ""
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
|
@ -140,8 +142,8 @@ class TriggerManager:
|
|||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
parameters: dict,
|
||||
credentials: dict,
|
||||
parameters: Mapping[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
request: Request,
|
||||
) -> TriggerInvokeResponse:
|
||||
|
|
@ -171,8 +173,8 @@ class TriggerManager:
|
|||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
endpoint: str,
|
||||
parameters: dict,
|
||||
credentials: dict,
|
||||
parameters: Mapping[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
) -> Subscription:
|
||||
"""
|
||||
Subscribe to a trigger (e.g., register webhook)
|
||||
|
|
@ -197,7 +199,7 @@ class TriggerManager:
|
|||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription: Subscription,
|
||||
credentials: dict,
|
||||
credentials: Mapping[str, str],
|
||||
) -> Unsubscription:
|
||||
"""
|
||||
Unsubscribe from a trigger
|
||||
|
|
@ -213,7 +215,7 @@ class TriggerManager:
|
|||
return provider.unsubscribe_trigger(user_id=user_id, subscription=subscription, credentials=credentials)
|
||||
|
||||
@classmethod
|
||||
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[ProviderConfig]:
|
||||
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> SubscriptionSchema:
|
||||
"""
|
||||
Get provider subscription schema
|
||||
|
||||
|
|
@ -228,9 +230,8 @@ class TriggerManager:
|
|||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
subscription: Subscription,
|
||||
credentials: dict,
|
||||
credentials: Mapping[str, str],
|
||||
) -> Subscription:
|
||||
"""
|
||||
Refresh a trigger subscription
|
||||
|
|
@ -242,7 +243,7 @@ class TriggerManager:
|
|||
:param credentials: Provider credentials
|
||||
:return: Refreshed subscription result
|
||||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(trigger_name, subscription, credentials)
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).refresh_trigger(subscription, credentials)
|
||||
|
||||
|
||||
# Export
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
from configs import dify_config
|
||||
|
||||
|
||||
def parse_endpoint_id(endpoint_id: str) -> str:
|
||||
return f"{dify_config.CONSOLE_API_URL}/console/api/trigger/endpoint/{endpoint_id}"
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""Optimize trigger provider endpoint index
|
||||
|
||||
Revision ID: 9d83760807c5
|
||||
Revises: 9ee7d347f4c1
|
||||
Create Date: 2025-09-01 12:42:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '9d83760807c5'
|
||||
down_revision = '9ee7d347f4c1'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# Drop the old unique constraint on endpoint
|
||||
with op.batch_alter_table('trigger_providers', schema=None) as batch_op:
|
||||
batch_op.drop_constraint('unique_trigger_provider_endpoint', type_='unique')
|
||||
|
||||
# Create a new unique index on endpoint for O(1) lookup
|
||||
batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint'], unique=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# Drop the new index
|
||||
with op.batch_alter_table('trigger_providers', schema=None) as batch_op:
|
||||
batch_op.drop_index('idx_trigger_providers_endpoint')
|
||||
|
||||
# Recreate the old unique constraint
|
||||
batch_op.create_unique_constraint('unique_trigger_provider_endpoint', ['endpoint'])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -10,6 +10,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
|||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.api_entities import TriggerProviderSubscriptionApiEntity
|
||||
from core.trigger.entities.entities import Subscription
|
||||
from core.trigger.utils.endpoint import parse_endpoint_id
|
||||
from models.base import Base
|
||||
from models.types import StringUUID
|
||||
|
||||
|
|
@ -22,10 +23,13 @@ class TriggerSubscription(Base):
|
|||
|
||||
__tablename__ = "trigger_subscriptions"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trigger_subscription_pkey"),
|
||||
Index("idx_trigger_subscriptions_tenant_provider", "tenant_id", "provider_id"),
|
||||
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_subscription"),
|
||||
UniqueConstraint("endpoint", name="unique_trigger_subscription_endpoint"),
|
||||
sa.PrimaryKeyConstraint("id", name="trigger_provider_pkey"),
|
||||
Index("idx_trigger_providers_tenant_provider", "tenant_id", "provider_id"),
|
||||
# Primary index for O(1) lookup by endpoint
|
||||
Index("idx_trigger_providers_endpoint", "endpoint_id", unique=True),
|
||||
# Composite index for tenant-specific queries (optional, kept for compatibility)
|
||||
Index("idx_trigger_providers_tenant_endpoint", "tenant_id", "endpoint_id"),
|
||||
UniqueConstraint("tenant_id", "provider_id", "name", name="unique_trigger_provider"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
|
|
@ -35,7 +39,7 @@ class TriggerSubscription(Base):
|
|||
provider_id: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
endpoint: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
parameters: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||
properties: Mapped[dict] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||
|
||||
|
|
@ -63,8 +67,7 @@ class TriggerSubscription(Base):
|
|||
def to_entity(self) -> Subscription:
|
||||
return Subscription(
|
||||
expires_at=self.expires_at,
|
||||
endpoint=self.endpoint,
|
||||
parameters=self.parameters,
|
||||
endpoint=parse_endpoint_id(self.endpoint_id),
|
||||
properties=self.properties,
|
||||
)
|
||||
|
||||
|
|
@ -77,6 +80,7 @@ class TriggerSubscription(Base):
|
|||
credentials=self.credentials,
|
||||
)
|
||||
|
||||
|
||||
# system level trigger oauth client params
|
||||
class TriggerOAuthSystemClient(Base):
|
||||
__tablename__ = "trigger_oauth_system_clients"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
|
|
@ -16,7 +15,6 @@ from core.plugin.entities.plugin_daemon import CredentialType
|
|||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.trigger.entities.api_entities import (
|
||||
SubscriptionValidation,
|
||||
TriggerProviderApiEntity,
|
||||
TriggerProviderSubscriptionApiEntity,
|
||||
)
|
||||
|
|
@ -36,6 +34,9 @@ logger = logging.getLogger(__name__)
|
|||
class TriggerProviderService:
|
||||
"""Service for managing trigger providers and credentials"""
|
||||
|
||||
##########################
|
||||
# Trigger provider
|
||||
##########################
|
||||
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
|
||||
|
||||
@classmethod
|
||||
|
|
@ -73,10 +74,14 @@ class TriggerProviderService:
|
|||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
name: str,
|
||||
provider_id: TriggerProviderID,
|
||||
endpoint_id: str,
|
||||
credential_type: CredentialType,
|
||||
credentials: dict,
|
||||
name: Optional[str] = None,
|
||||
parameters: Mapping[str, Any],
|
||||
properties: Mapping[str, Any],
|
||||
credentials: Mapping[str, str],
|
||||
credential_expires_at: int = -1,
|
||||
expires_at: int = -1,
|
||||
) -> dict:
|
||||
"""
|
||||
|
|
@ -93,7 +98,7 @@ class TriggerProviderService:
|
|||
"""
|
||||
try:
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
# Use distributed lock to prevent race conditions
|
||||
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
|
|
@ -110,23 +115,14 @@ class TriggerProviderService:
|
|||
f"reached for {provider_id}"
|
||||
)
|
||||
|
||||
# Generate name if not provided
|
||||
if not name:
|
||||
name = cls._generate_provider_name(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -138,10 +134,14 @@ class TriggerProviderService:
|
|||
db_provider = TriggerSubscription(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type.value,
|
||||
credentials=encrypter.encrypt(credentials),
|
||||
name=name,
|
||||
endpoint_id=endpoint_id,
|
||||
provider_id=provider_id,
|
||||
parameters=parameters,
|
||||
properties=properties,
|
||||
credentials=encrypter.encrypt(dict(credentials)),
|
||||
credential_type=credential_type.value,
|
||||
credential_expires_at=credential_expires_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
|
|
@ -154,70 +154,6 @@ class TriggerProviderService:
|
|||
logger.exception("Failed to add trigger provider")
|
||||
raise ValueError(str(e))
|
||||
|
||||
@classmethod
|
||||
def update_trigger_provider(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
credentials: Optional[dict] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Update an existing trigger provider's credentials or name.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param subscription_id: Subscription instance ID
|
||||
:param credentials: New credentials (optional)
|
||||
:param name: New name (optional)
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
db_provider = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first()
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
|
||||
try:
|
||||
provider_controller = TriggerManager.get_trigger_provider(
|
||||
tenant_id, TriggerProviderID(db_provider.provider_id)
|
||||
)
|
||||
|
||||
if credentials:
|
||||
encrypter, cache = create_trigger_provider_encrypter_for_subscription(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
subscription=db_provider,
|
||||
)
|
||||
# Handle hidden values
|
||||
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
new_credentials = {
|
||||
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
db_provider.credentials = encrypter.encrypt(new_credentials)
|
||||
cache.delete()
|
||||
|
||||
# Update name if provided
|
||||
if name and name != db_provider.name:
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name)
|
||||
.filter(TriggerSubscription.id != subscription_id)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists")
|
||||
|
||||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
@classmethod
|
||||
def delete_trigger_provider(cls, tenant_id: str, subscription_id: str) -> dict:
|
||||
"""
|
||||
|
|
@ -505,59 +441,6 @@ class TriggerProviderService:
|
|||
)
|
||||
return custom_client is not None
|
||||
|
||||
@classmethod
|
||||
def _generate_provider_name(
|
||||
cls,
|
||||
session: Session,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
credential_type: CredentialType,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique name for a provider credential instance.
|
||||
|
||||
:param session: Database session
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider: Provider identifier
|
||||
:param credential_type: Credential type
|
||||
:return: Generated name
|
||||
"""
|
||||
try:
|
||||
db_providers = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
.order_by(desc(TriggerSubscription.created_at))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get base name
|
||||
base_name = credential_type.get_name()
|
||||
|
||||
# Find existing numbered names
|
||||
pattern = rf"^{re.escape(base_name)}\s+(\d+)$"
|
||||
numbers = []
|
||||
|
||||
for db_provider in db_providers:
|
||||
if db_provider.name:
|
||||
match = re.match(pattern, db_provider.name.strip())
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
# Generate next number
|
||||
if not numbers:
|
||||
return f"{base_name} 1"
|
||||
|
||||
max_number = max(numbers)
|
||||
return f"{base_name} {max_number + 1}"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Error generating provider name")
|
||||
return f"{credential_type.get_name()} 1"
|
||||
|
||||
@classmethod
|
||||
def get_subscription_by_endpoint(cls, endpoint_id: str) -> TriggerSubscription | None:
|
||||
"""
|
||||
|
|
@ -566,15 +449,3 @@ class TriggerProviderService:
|
|||
with Session(db.engine, autoflush=False) as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first()
|
||||
return subscription
|
||||
|
||||
@classmethod
|
||||
def get_subscription_validation(cls, endpoint_id: str) -> SubscriptionValidation | None:
|
||||
"""
|
||||
Get a trigger subscription by the endpoint ID.
|
||||
"""
|
||||
cache_key = f"trigger:subscription:validation:endpoint:{endpoint_id}"
|
||||
subscription_cache = redis_client.get(cache_key)
|
||||
if subscription_cache:
|
||||
return SubscriptionValidation.model_validate(json.loads(subscription_cache))
|
||||
|
||||
return None
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import (
|
||||
RequestLog,
|
||||
SubscriptionBuilder,
|
||||
)
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderService:
|
||||
"""Service for managing trigger providers and credentials"""
|
||||
|
||||
##########################
|
||||
# Trigger provider
|
||||
##########################
|
||||
__MAX_TRIGGER_PROVIDER_COUNT__ = 10
|
||||
|
||||
##########################
|
||||
# Validation endpoint
|
||||
##########################
|
||||
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
|
||||
__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 30 * 60 * 1000
|
||||
|
||||
@classmethod
|
||||
def encode_cache_key(cls, subscription_id: str) -> str:
|
||||
return f"trigger:subscription:validation:{subscription_id}"
|
||||
|
||||
@classmethod
|
||||
def verify_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
) -> None:
|
||||
"""Verify a trigger subscription builder"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
|
||||
provider_controller.validate_credentials(user_id, subscription_builder.credentials)
|
||||
|
||||
@classmethod
|
||||
def build_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
) -> None:
|
||||
"""Build a trigger subscription builder"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription_builder = cls.get_subscription_builder(subscription_builder_id)
|
||||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
|
||||
if subscription_builder.name is None:
|
||||
raise ValueError("Subscription builder name is required")
|
||||
|
||||
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value)
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
# manually create
|
||||
TriggerProviderService.add_trigger_provider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription_builder.properties,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# automatically create
|
||||
subscription = TriggerManager.subscribe_trigger(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
credentials=subscription_builder.credentials,
|
||||
)
|
||||
|
||||
TriggerProviderService.add_trigger_provider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
name=subscription_builder.name,
|
||||
provider_id=provider_id,
|
||||
endpoint_id=subscription_builder.endpoint_id,
|
||||
parameters=subscription_builder.parameters,
|
||||
properties=subscription.properties,
|
||||
credentials=subscription_builder.credentials,
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=subscription_builder.credential_expires_at or -1,
|
||||
expires_at=subscription_builder.expires_at,
|
||||
)
|
||||
|
||||
cls.delete_trigger_subscription_builder(subscription_builder_id)
|
||||
|
||||
@classmethod
|
||||
def create_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
credentials: Mapping[str, str],
|
||||
credential_type: CredentialType,
|
||||
credential_expires_at: int,
|
||||
expires_at: int,
|
||||
) -> SubscriptionBuilder:
|
||||
"""
|
||||
Add a new trigger subscription validation.
|
||||
"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
subscription_schema = provider_controller.get_subscription_schema()
|
||||
subscription_id = str(uuid.uuid4())
|
||||
subscription_builder = SubscriptionBuilder(
|
||||
id=subscription_id,
|
||||
name="",
|
||||
endpoint_id=subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=str(provider_id),
|
||||
parameters=subscription_schema.get_default_parameters(),
|
||||
properties=subscription_schema.get_default_properties(),
|
||||
credentials=credentials,
|
||||
credential_type=credential_type,
|
||||
credential_expires_at=credential_expires_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
redis_client.setex(
|
||||
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json()
|
||||
)
|
||||
return subscription_builder
|
||||
|
||||
@classmethod
|
||||
def update_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_builder: SubscriptionBuilder,
|
||||
) -> SubscriptionBuilder:
|
||||
"""
|
||||
Update a trigger subscription validation.
|
||||
"""
|
||||
subscription_id = subscription_builder.id
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
subscription_builder_cache = cls.get_subscription_builder(subscription_id)
|
||||
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
|
||||
redis_client.setex(
|
||||
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json()
|
||||
)
|
||||
return subscription_builder
|
||||
|
||||
@classmethod
|
||||
def delete_trigger_subscription_builder(cls, subscription_id: str) -> None:
|
||||
"""
|
||||
Delete a trigger subscription validation.
|
||||
"""
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
redis_client.delete(cache_key)
|
||||
|
||||
@classmethod
|
||||
def get_subscription_builder(cls, endpoint_id: str) -> SubscriptionBuilder | None:
|
||||
"""
|
||||
Get a trigger subscription by the endpoint ID.
|
||||
"""
|
||||
cache_key = cls.encode_cache_key(endpoint_id)
|
||||
subscription_cache = redis_client.get(cache_key)
|
||||
if subscription_cache:
|
||||
return SubscriptionBuilder.model_validate(json.loads(subscription_cache))
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def append_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
|
||||
"""
|
||||
Append the validation request log to Redis.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def list_request_logs(cls, endpoint_id: str) -> list[RequestLog]:
|
||||
"""
|
||||
List the request logs for a validation endpoint.
|
||||
"""
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def process_builder_validation_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
"""
|
||||
Process a temporary endpoint request.
|
||||
|
||||
:param endpoint_id: The endpoint identifier
|
||||
:param request: The Flask request object
|
||||
:return: The Flask response object
|
||||
"""
|
||||
# check if validation endpoint exists
|
||||
subscription_builder = cls.get_subscription_builder(endpoint_id)
|
||||
if not subscription_builder:
|
||||
return None
|
||||
|
||||
# response to validation endpoint
|
||||
controller = TriggerManager.get_trigger_provider(
|
||||
subscription_builder.tenant_id, TriggerProviderID(subscription_builder.provider_id)
|
||||
)
|
||||
response = controller.dispatch(
|
||||
user_id=subscription_builder.user_id,
|
||||
request=request,
|
||||
subscription=subscription_builder.to_subscription(),
|
||||
)
|
||||
# append the request log
|
||||
cls.append_request_log(endpoint_id, request, response.response)
|
||||
return response.response
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
import logging
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerSubscriptionValidationService:
|
||||
__VALIDATION_REQUEST_CACHE_COUNT__ = 10
|
||||
__VALIDATION_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
|
||||
|
||||
@classmethod
|
||||
def append_validation_request_log(cls, endpoint_id: str, request: Request, response: Response) -> None:
|
||||
"""
|
||||
Append the validation request log to Redis.
|
||||
"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def process_validating_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
"""
|
||||
Process a temporary endpoint request.
|
||||
|
||||
:param endpoint_id: The endpoint identifier
|
||||
:param request: The Flask request object
|
||||
:return: The Flask response object
|
||||
"""
|
||||
# check if validation endpoint exists
|
||||
subscription_validation = TriggerProviderService.get_subscription_validation(endpoint_id)
|
||||
if not subscription_validation:
|
||||
return None
|
||||
|
||||
# response to validation endpoint
|
||||
controller = TriggerManager.get_trigger_provider(
|
||||
subscription_validation.tenant_id, TriggerProviderID(subscription_validation.provider_id)
|
||||
)
|
||||
response = controller.dispatch(
|
||||
user_id=subscription_validation.user_id,
|
||||
request=request,
|
||||
subscription=subscription_validation.to_subscription(),
|
||||
)
|
||||
# append the request log
|
||||
cls.append_validation_request_log(endpoint_id, request, response.response)
|
||||
return response.response
|
||||
|
|
@ -1,14 +1,11 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.trigger.entities.entities import TriggerEntity
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -18,34 +15,27 @@ class TriggerService:
|
|||
__TEMPORARY_ENDPOINT_EXPIRE_MS__ = 5 * 60 * 1000
|
||||
__ENDPOINT_REQUEST_CACHE_COUNT__ = 10
|
||||
__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__ = 5 * 60 * 1000
|
||||
# Lua script for atomic write with time & count based cleanup
|
||||
__LUA_SCRIPT__ = """
|
||||
-- KEYS[1] = zset key
|
||||
-- ARGV[1] = max_count (maximum number of entries to keep)
|
||||
-- ARGV[2] = min_ts_ms (minimum timestamp to keep = now_ms - ttl_ms)
|
||||
-- ARGV[3] = now_ms (current timestamp in milliseconds)
|
||||
-- ARGV[4] = member (log entry JSON)
|
||||
|
||||
local key = KEYS[1]
|
||||
local maxCount = tonumber(ARGV[1])
|
||||
local minTs = tonumber(ARGV[2])
|
||||
local nowMs = tonumber(ARGV[3])
|
||||
local member = ARGV[4]
|
||||
@classmethod
|
||||
def process_triggered_workflows(cls, subscription: TriggerSubscription, trigger: TriggerEntity, request: Request) -> None:
|
||||
"""Process triggered workflows."""
|
||||
|
||||
|
||||
-- 1) Add new entry with timestamp as score
|
||||
redis.call('ZADD', key, nowMs, member)
|
||||
|
||||
-- 2) Remove entries older than minTs (time-based cleanup)
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', minTs)
|
||||
|
||||
-- 3) Remove oldest entries if count exceeds maxCount (count-based cleanup)
|
||||
local n = redis.call('ZCARD', key)
|
||||
if n > maxCount then
|
||||
redis.call('ZREMRANGEBYRANK', key, 0, n - maxCount - 1) -- 0 is oldest
|
||||
end
|
||||
|
||||
return n
|
||||
"""
|
||||
@classmethod
|
||||
def select_triggers(cls, controller, dispatch_response, provider_id, subscription) -> list[TriggerEntity]:
|
||||
triggers = []
|
||||
for trigger_name in dispatch_response.triggers:
|
||||
trigger = controller.get_trigger(trigger_name)
|
||||
if trigger is None:
|
||||
logger.error(
|
||||
"Trigger '%s' not found in provider '%s' for tenant '%s'",
|
||||
trigger_name,
|
||||
provider_id,
|
||||
subscription.tenant_id,
|
||||
)
|
||||
raise ValueError(f"Trigger '{trigger_name}' not found")
|
||||
triggers.append(trigger)
|
||||
return triggers
|
||||
|
||||
@classmethod
|
||||
def process_endpoint(cls, endpoint_id: str, request: Request) -> Response | None:
|
||||
|
|
@ -53,140 +43,23 @@ class TriggerService:
|
|||
subscription = TriggerProviderService.get_subscription_by_endpoint(endpoint_id)
|
||||
if not subscription:
|
||||
return None
|
||||
|
||||
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
controller = TriggerManager.get_trigger_provider(subscription.tenant_id, provider_id)
|
||||
if not controller:
|
||||
return None
|
||||
|
||||
|
||||
dispatch_response = controller.dispatch(
|
||||
user_id=subscription.user_id, request=request, subscription=subscription.to_entity()
|
||||
)
|
||||
|
||||
# TODO invoke triggers
|
||||
# dispatch_response.triggers
|
||||
|
||||
if dispatch_response.triggers:
|
||||
triggers = cls.select_triggers(controller, dispatch_response, provider_id, subscription)
|
||||
for trigger in triggers:
|
||||
cls.process_triggered_workflows(
|
||||
subscription=subscription,
|
||||
trigger=trigger,
|
||||
request=request,
|
||||
)
|
||||
return dispatch_response.response
|
||||
|
||||
@classmethod
|
||||
def log_endpoint_request(cls, endpoint_id: str, request: Request) -> int:
|
||||
"""
|
||||
Log the endpoint request to Redis using ZSET for rolling log with time & count based retention.
|
||||
|
||||
Args:
|
||||
endpoint_id: The endpoint identifier
|
||||
request: The Flask request object
|
||||
|
||||
Returns:
|
||||
The current number of logged requests for this endpoint
|
||||
"""
|
||||
try:
|
||||
# Prepare timestamp
|
||||
now_ms = int(time.time() * 1000)
|
||||
min_ts = now_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__
|
||||
|
||||
# Extract request data
|
||||
request_data = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"timestamp": now_ms,
|
||||
"method": request.method,
|
||||
"path": request.path,
|
||||
"headers": dict(request.headers),
|
||||
"query_params": request.args.to_dict(flat=False) if request.args else {},
|
||||
"body": None,
|
||||
"remote_addr": request.remote_addr,
|
||||
}
|
||||
|
||||
# Try to get request body if it exists
|
||||
if request.is_json:
|
||||
try:
|
||||
request_data["body"] = request.get_json(force=True)
|
||||
except Exception:
|
||||
request_data["body"] = request.get_data(as_text=True)
|
||||
elif request.data:
|
||||
request_data["body"] = request.get_data(as_text=True)
|
||||
|
||||
# Serialize to JSON
|
||||
member = json.dumps(request_data, separators=(",", ":"))
|
||||
|
||||
# Execute Lua script atomically
|
||||
key = f"trigger:endpoint_requests:{endpoint_id}"
|
||||
count = redis_client.eval(
|
||||
cls.__LUA_SCRIPT__,
|
||||
1, # number of keys
|
||||
key, # KEYS[1]
|
||||
str(cls.__ENDPOINT_REQUEST_CACHE_COUNT__), # ARGV[1] - max count
|
||||
str(min_ts), # ARGV[2] - minimum timestamp
|
||||
str(now_ms), # ARGV[3] - current timestamp
|
||||
member, # ARGV[4] - log entry
|
||||
)
|
||||
|
||||
logger.debug("Logged request for endpoint %s, current count: %s", endpoint_id, count)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to log endpoint request for %s", endpoint_id, exc_info=e)
|
||||
# Don't fail the main request processing if logging fails
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def get_recent_endpoint_requests(
|
||||
cls, endpoint_id: str, limit: int = 100, start_time_ms: Optional[int] = None, end_time_ms: Optional[int] = None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve recent logged requests for an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_id: The endpoint identifier
|
||||
limit: Maximum number of entries to return
|
||||
start_time_ms: Start timestamp in milliseconds (optional)
|
||||
end_time_ms: End timestamp in milliseconds (optional, defaults to now)
|
||||
|
||||
Returns:
|
||||
List of request log entries, newest first
|
||||
"""
|
||||
try:
|
||||
key = f"trigger:endpoint_requests:{endpoint_id}"
|
||||
|
||||
# Set time bounds
|
||||
if end_time_ms is None:
|
||||
end_time_ms = int(time.time() * 1000)
|
||||
if start_time_ms is None:
|
||||
start_time_ms = end_time_ms - cls.__ENDPOINT_REQUEST_CACHE_EXPIRE_MS__
|
||||
|
||||
# Get entries in reverse order (newest first)
|
||||
entries = redis_client.zrevrangebyscore(key, max=end_time_ms, min=start_time_ms, start=0, num=limit)
|
||||
|
||||
# Parse JSON entries
|
||||
requests = []
|
||||
for entry in entries:
|
||||
try:
|
||||
requests.append(json.loads(entry))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse log entry: %s", entry)
|
||||
|
||||
return requests
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to retrieve endpoint requests for %s", endpoint_id, exc_info=e)
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def clear_endpoint_requests(cls, endpoint_id: str) -> bool:
|
||||
"""
|
||||
Clear all logged requests for an endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_id: The endpoint identifier
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
key = f"trigger:endpoint_requests:{endpoint_id}"
|
||||
redis_client.delete(key)
|
||||
logger.info("Cleared request logs for endpoint %s", endpoint_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception("Failed to clear endpoint requests for %s", endpoint_id, exc_info=e)
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -420,7 +420,7 @@ class TestFileUploads:
|
|||
|
||||
assert b"POST /api/upload HTTP/1.1\r\n" in raw_data
|
||||
assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
|
||||
assert b"Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"" in raw_data
|
||||
assert b'Content-Disposition: form-data; name="file"; filename="test.txt"' in raw_data
|
||||
assert text_content.encode() in raw_data
|
||||
|
||||
def test_deserialize_request_with_text_file_upload(self):
|
||||
|
|
@ -465,7 +465,7 @@ class TestFileUploads:
|
|||
boundary = "----BoundaryString123"
|
||||
# Simulate a small PNG file header
|
||||
binary_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10"
|
||||
|
||||
|
||||
# Build multipart body
|
||||
body_parts = []
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
|
|
@ -478,7 +478,7 @@ class TestFileUploads:
|
|||
body_parts.append(b"")
|
||||
body_parts.append(b"Test image")
|
||||
body_parts.append(f"------{boundary}--".encode())
|
||||
|
||||
|
||||
body = b"\r\n".join(body_parts)
|
||||
|
||||
environ = {
|
||||
|
|
@ -499,16 +499,16 @@ class TestFileUploads:
|
|||
|
||||
assert b"POST /api/images HTTP/1.1\r\n" in raw_data
|
||||
assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
|
||||
assert b"filename=\"test.png\"" in raw_data
|
||||
assert b'filename="test.png"' in raw_data
|
||||
assert b"Content-Type: image/png" in raw_data
|
||||
assert binary_content in raw_data
|
||||
|
||||
def test_deserialize_request_with_binary_file_upload(self):
|
||||
# Test deserializing multipart/form-data request with binary file
|
||||
boundary = "----BoundaryABC123"
|
||||
# Simulate a small JPEG file header
|
||||
# Simulate a small JPEG file header
|
||||
binary_content = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00"
|
||||
|
||||
|
||||
body_parts = []
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="photo"; filename="photo.jpg"')
|
||||
|
|
@ -520,7 +520,7 @@ class TestFileUploads:
|
|||
body_parts.append(b"")
|
||||
body_parts.append(b"Vacation 2024")
|
||||
body_parts.append(f"------{boundary}--".encode())
|
||||
|
||||
|
||||
body = b"\r\n".join(body_parts)
|
||||
|
||||
raw_data = (
|
||||
|
|
@ -538,7 +538,7 @@ class TestFileUploads:
|
|||
assert request.path == "/api/photos"
|
||||
assert "multipart/form-data" in request.content_type
|
||||
assert request.headers.get("Accept") == "application/json"
|
||||
|
||||
|
||||
# Verify the binary content is preserved
|
||||
request_body = request.get_data()
|
||||
assert b"photo.jpg" in request_body
|
||||
|
|
@ -553,7 +553,7 @@ class TestFileUploads:
|
|||
boundary = "----MultiFilesBoundary"
|
||||
text_file = b"Text file contents"
|
||||
binary_file = b"\x00\x01\x02\x03\x04\x05"
|
||||
|
||||
|
||||
body_parts = []
|
||||
# First file (text)
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
|
|
@ -573,7 +573,7 @@ class TestFileUploads:
|
|||
body_parts.append(b"")
|
||||
body_parts.append(b"uploads/2024")
|
||||
body_parts.append(f"------{boundary}--".encode())
|
||||
|
||||
|
||||
body = b"\r\n".join(body_parts)
|
||||
|
||||
environ = {
|
||||
|
|
@ -606,7 +606,7 @@ class TestFileUploads:
|
|||
|
||||
boundary = "----RoundTripBoundary"
|
||||
file_content = b"This is my file content with special chars: \xf0\x9f\x98\x80"
|
||||
|
||||
|
||||
body_parts = []
|
||||
body_parts.append(f"------{boundary}".encode())
|
||||
body_parts.append(b'Content-Disposition: form-data; name="upload"; filename="emoji.txt"')
|
||||
|
|
@ -618,7 +618,7 @@ class TestFileUploads:
|
|||
body_parts.append(b"")
|
||||
body_parts.append(b'{"encoding": "utf-8", "size": 42}')
|
||||
body_parts.append(f"------{boundary}--".encode())
|
||||
|
||||
|
||||
body = b"\r\n".join(body_parts)
|
||||
|
||||
environ = {
|
||||
|
|
@ -647,7 +647,7 @@ class TestFileUploads:
|
|||
assert restored_request.query_string == b"version=2"
|
||||
assert "multipart/form-data" in restored_request.content_type
|
||||
assert boundary in restored_request.content_type
|
||||
|
||||
|
||||
# Verify file content is preserved
|
||||
restored_body = restored_request.get_data()
|
||||
assert b"emoji.txt" in restored_body
|
||||
|
|
|
|||
Loading…
Reference in New Issue