mirror of https://github.com/langgenius/dify.git
feat(trigger): add plugin trigger workflow support and refactor trigger system
- Add new workflow plugin trigger service for managing plugin-based triggers - Implement trigger provider encryption utilities for secure credential storage - Add custom trigger errors module for better error handling - Refactor trigger provider and manager classes for improved plugin integration - Update API endpoints to support plugin trigger workflows - Add database migration for plugin trigger workflow support 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
cc84a45244
commit
a62d7aa3ee
|
|
@ -18,7 +18,7 @@ from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from models.workflow import WorkflowPluginTrigger
|
||||
from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService
|
||||
|
||||
|
||||
class PluginTriggerApi(Resource):
|
||||
|
|
@ -34,54 +34,21 @@ class PluginTriggerApi(Resource):
|
|||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
parser.add_argument("provider_id", type=str, required=True, help="Provider ID is required")
|
||||
parser.add_argument("trigger_name", type=str, required=True, help="Trigger name is required")
|
||||
parser.add_argument(
|
||||
"triggered_by",
|
||||
type=str,
|
||||
required=False,
|
||||
default="production",
|
||||
choices=["debugger", "production"],
|
||||
help="triggered_by must be debugger or production",
|
||||
)
|
||||
parser.add_argument("subscription_id", type=str, required=True, help="Subscription ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
node_id = args["node_id"]
|
||||
provider_id = args["provider_id"]
|
||||
trigger_name = args["trigger_name"]
|
||||
triggered_by = args["triggered_by"]
|
||||
|
||||
# Create trigger_id from provider_id and trigger_name
|
||||
trigger_id = f"{provider_id}:{trigger_name}"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Check if plugin trigger already exists for this app, node, and environment
|
||||
existing_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_model.id,
|
||||
WorkflowPluginTrigger.node_id == node_id,
|
||||
WorkflowPluginTrigger.triggered_by == triggered_by,
|
||||
)
|
||||
)
|
||||
|
||||
if existing_trigger:
|
||||
raise BadRequest("Plugin trigger already exists for this node and environment")
|
||||
|
||||
# Create new plugin trigger
|
||||
plugin_trigger = WorkflowPluginTrigger(
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger_id=trigger_id,
|
||||
triggered_by=triggered_by,
|
||||
)
|
||||
|
||||
session.add(plugin_trigger)
|
||||
session.commit()
|
||||
session.refresh(plugin_trigger)
|
||||
plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
node_id=args["node_id"],
|
||||
provider_id=args["provider_id"],
|
||||
trigger_name=args["trigger_name"],
|
||||
subscription_id=args["subscription_id"],
|
||||
)
|
||||
|
||||
return plugin_trigger
|
||||
|
||||
|
|
@ -93,33 +60,14 @@ class PluginTriggerApi(Resource):
|
|||
"""Get plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
parser.add_argument(
|
||||
"triggered_by",
|
||||
type=str,
|
||||
required=False,
|
||||
default="production",
|
||||
choices=["debugger", "production"],
|
||||
help="triggered_by must be debugger or production",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
node_id = args["node_id"]
|
||||
triggered_by = args["triggered_by"]
|
||||
plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_model.id,
|
||||
WorkflowPluginTrigger.node_id == node_id,
|
||||
WorkflowPluginTrigger.triggered_by == triggered_by,
|
||||
WorkflowPluginTrigger.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
return plugin_trigger
|
||||
return plugin_trigger
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -131,51 +79,22 @@ class PluginTriggerApi(Resource):
|
|||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
parser.add_argument("provider_id", type=str, required=False, help="Provider ID")
|
||||
parser.add_argument("trigger_name", type=str, required=False, help="Trigger name")
|
||||
parser.add_argument(
|
||||
"triggered_by",
|
||||
type=str,
|
||||
required=False,
|
||||
default="production",
|
||||
choices=["debugger", "production"],
|
||||
help="triggered_by must be debugger or production",
|
||||
)
|
||||
parser.add_argument("subscription_id", type=str, required=False, help="Subscription ID")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
node_id = args["node_id"]
|
||||
triggered_by = args["triggered_by"]
|
||||
plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
provider_id=args.get("provider_id"),
|
||||
trigger_name=args.get("trigger_name"),
|
||||
subscription_id=args.get("subscription_id"),
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_model.id,
|
||||
WorkflowPluginTrigger.node_id == node_id,
|
||||
WorkflowPluginTrigger.triggered_by == triggered_by,
|
||||
WorkflowPluginTrigger.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
|
||||
# Update fields if provided
|
||||
if args.get("provider_id"):
|
||||
plugin_trigger.provider_id = args["provider_id"]
|
||||
|
||||
if args.get("trigger_name"):
|
||||
# Update trigger_id if provider_id or trigger_name changed
|
||||
provider_id = args.get("provider_id") or plugin_trigger.provider_id
|
||||
trigger_name = args["trigger_name"]
|
||||
plugin_trigger.trigger_id = f"{provider_id}:{trigger_name}"
|
||||
|
||||
session.commit()
|
||||
session.refresh(plugin_trigger)
|
||||
|
||||
return plugin_trigger
|
||||
return plugin_trigger
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -185,39 +104,16 @@ class PluginTriggerApi(Resource):
|
|||
"""Delete plugin trigger"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
parser.add_argument(
|
||||
"triggered_by",
|
||||
type=str,
|
||||
required=False,
|
||||
default="production",
|
||||
choices=["debugger", "production"],
|
||||
help="triggered_by must be debugger or production",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
node_id = args["node_id"]
|
||||
triggered_by = args["triggered_by"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_model.id,
|
||||
WorkflowPluginTrigger.node_id == node_id,
|
||||
WorkflowPluginTrigger.triggered_by == triggered_by,
|
||||
WorkflowPluginTrigger.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
|
||||
session.delete(plugin_trigger)
|
||||
session.commit()
|
||||
WorkflowPluginTriggerService.delete_plugin_trigger(
|
||||
app_id=app_model.id,
|
||||
node_id=args["node_id"],
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
|
|
|||
|
|
@ -117,6 +117,43 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
|||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
# The name of the subscription builder
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
# The parameters of the subscription builder
|
||||
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
|
||||
# The properties of the subscription builder
|
||||
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
|
||||
# The credentials of the subscription builder
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
subscription_builder_id=subscription_builder_id,
|
||||
name=args.get("name", None),
|
||||
parameters=args.get("parameters", None),
|
||||
properties=args.get("properties", None),
|
||||
credentials=args.get("credentials", None),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error updating provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -216,9 +253,26 @@ class TriggerOAuthAuthorizeApi(Resource):
|
|||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
# Create subscription builder
|
||||
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=provider_id,
|
||||
credentials={},
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
credential_expires_at=0,
|
||||
expires_at=0,
|
||||
)
|
||||
|
||||
# Create response with cookie
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response = make_response(
|
||||
jsonable_encoder(
|
||||
{
|
||||
"authorization_url": authorization_url_response,
|
||||
"subscription_builder": subscription_builder,
|
||||
}
|
||||
)
|
||||
)
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
|
|
@ -410,6 +464,10 @@ api.add_resource(
|
|||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
|
|
|
|||
|
|
@ -14,9 +14,7 @@ UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f
|
|||
UUID_MATCHER = re.compile(UUID_PATTERN)
|
||||
|
||||
|
||||
@bp.route(
|
||||
"/trigger/endpoint/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
|
||||
)
|
||||
@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def trigger_endpoint(endpoint_id: str):
|
||||
"""
|
||||
Handle endpoint trigger calls.
|
||||
|
|
|
|||
|
|
@ -254,11 +254,13 @@ class TriggerSubscriptionResponse(BaseModel):
|
|||
|
||||
|
||||
class TriggerValidateProviderCredentialsResponse(BaseModel):
|
||||
valid: bool
|
||||
message: str
|
||||
error: str
|
||||
result: bool
|
||||
|
||||
|
||||
class TriggerDispatchResponse:
|
||||
triggers: list[str]
|
||||
response: Response
|
||||
|
||||
def __init__(self, triggers: list[str], response: Response):
|
||||
self.triggers = triggers
|
||||
self.response = response
|
||||
|
|
|
|||
|
|
@ -42,11 +42,10 @@ class PluginTriggerManager(BasePluginClient):
|
|||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
provider.declaration.identity.name = str(provider.provider)
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for trigger in provider.declaration.triggers:
|
||||
trigger.identity.provider = provider.declaration.identity.name
|
||||
trigger.identity.provider = str(provider.provider)
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -59,7 +58,7 @@ class PluginTriggerManager(BasePluginClient):
|
|||
data = json_response.get("data")
|
||||
if data:
|
||||
for trigger in data.get("declaration", {}).get("triggers", []):
|
||||
trigger["identity"]["provider"] = provider_id.provider_name
|
||||
trigger["identity"]["provider"] = str(provider_id)
|
||||
|
||||
return json_response
|
||||
|
||||
|
|
@ -71,11 +70,11 @@ class PluginTriggerManager(BasePluginClient):
|
|||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
response.declaration.identity.name = str(provider_id)
|
||||
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for trigger in response.declaration.triggers:
|
||||
trigger.identity.provider = response.declaration.identity.name
|
||||
trigger.identity.provider = str(provider_id)
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -123,7 +122,7 @@ class PluginTriggerManager(BasePluginClient):
|
|||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str]
|
||||
) -> TriggerValidateProviderCredentialsResponse:
|
||||
) -> bool:
|
||||
"""
|
||||
Validate the credentials of the trigger provider.
|
||||
"""
|
||||
|
|
@ -147,9 +146,9 @@ class PluginTriggerManager(BasePluginClient):
|
|||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
return resp.result
|
||||
|
||||
return TriggerValidateProviderCredentialsResponse(valid=False, message="No response", error="No response")
|
||||
raise ValueError("No response received from plugin daemon for validate provider credentials")
|
||||
|
||||
def dispatch_event(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -43,4 +43,15 @@ class TriggerApiEntity(BaseModel):
|
|||
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
|
||||
|
||||
|
||||
class SubscriptionBuilderApiEntity(BaseModel):
|
||||
id: str = Field(description="The id of the subscription builder")
|
||||
name: str = Field(description="The name of the subscription builder")
|
||||
provider: str = Field(description="The provider id of the subscription builder")
|
||||
endpoint: str = Field(description="The endpoint id of the subscription builder")
|
||||
parameters: Mapping[str, Any] = Field(description="The parameters of the subscription builder")
|
||||
properties: Mapping[str, Any] = Field(description="The properties of the subscription builder")
|
||||
credentials: Mapping[str, str] = Field(description="The credentials of the subscription builder")
|
||||
credential_type: CredentialType = Field(description="The credential type of the subscription builder")
|
||||
|
||||
|
||||
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
class TriggerProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
|
@ -14,7 +14,6 @@ from core.plugin.entities.plugin_daemon import CredentialType
|
|||
from core.plugin.entities.request import (
|
||||
TriggerDispatchResponse,
|
||||
TriggerInvokeResponse,
|
||||
TriggerValidateProviderCredentialsResponse,
|
||||
)
|
||||
from core.plugin.impl.trigger import PluginTriggerManager
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity
|
||||
|
|
@ -27,6 +26,7 @@ from core.trigger.entities.entities import (
|
|||
TriggerProviderIdentity,
|
||||
Unsubscription,
|
||||
)
|
||||
from core.trigger.errors import TriggerProviderCredentialValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -41,6 +41,7 @@ class PluginTriggerProviderController:
|
|||
entity: TriggerProviderEntity,
|
||||
plugin_id: str,
|
||||
plugin_unique_identifier: str,
|
||||
provider_id: TriggerProviderID,
|
||||
tenant_id: str,
|
||||
):
|
||||
"""
|
||||
|
|
@ -49,18 +50,20 @@ class PluginTriggerProviderController:
|
|||
:param entity: Trigger provider entity
|
||||
:param plugin_id: Plugin ID
|
||||
:param plugin_unique_identifier: Plugin unique identifier
|
||||
:param provider_id: Provider ID
|
||||
:param tenant_id: Tenant ID
|
||||
"""
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.provider_id = provider_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_provider_id(self) -> TriggerProviderID:
|
||||
"""
|
||||
Get provider ID
|
||||
"""
|
||||
return TriggerProviderID(f"{self.plugin_id}/{self.entity.identity.name}")
|
||||
return self.provider_id
|
||||
|
||||
def to_api_entity(self) -> TriggerProviderApiEntity:
|
||||
"""
|
||||
|
|
@ -101,9 +104,7 @@ class PluginTriggerProviderController:
|
|||
"""
|
||||
return self.entity.subscription_schema
|
||||
|
||||
def validate_credentials(
|
||||
self, user_id: str, credentials: Mapping[str, str]
|
||||
) -> TriggerValidateProviderCredentialsResponse:
|
||||
def validate_credentials(self, user_id: str, credentials: Mapping[str, str]) -> None:
|
||||
"""
|
||||
Validate credentials against schema
|
||||
|
||||
|
|
@ -113,21 +114,21 @@ class PluginTriggerProviderController:
|
|||
# First validate against schema
|
||||
for config in self.entity.credentials_schema:
|
||||
if config.required and config.name not in credentials:
|
||||
return TriggerValidateProviderCredentialsResponse(
|
||||
valid=False,
|
||||
message=f"Missing required credential field: {config.name}",
|
||||
error=f"Missing required credential field: {config.name}",
|
||||
)
|
||||
raise TriggerProviderCredentialValidationError(f"Missing required credential field: {config.name}")
|
||||
|
||||
# Then validate with the plugin daemon
|
||||
manager = PluginTriggerManager()
|
||||
provider_id = self.get_provider_id()
|
||||
return manager.validate_provider_credentials(
|
||||
response = manager.validate_provider_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=str(provider_id),
|
||||
credentials=credentials,
|
||||
)
|
||||
if not response:
|
||||
raise TriggerProviderCredentialValidationError(
|
||||
"Invalid credentials",
|
||||
)
|
||||
|
||||
def get_supported_credential_types(self) -> list[CredentialType]:
|
||||
"""
|
||||
|
|
@ -154,6 +155,8 @@ class PluginTriggerProviderController:
|
|||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
return []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ class TriggerManager:
|
|||
entity=provider.declaration,
|
||||
plugin_id=provider.plugin_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
provider_id=TriggerProviderID(provider.provider),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
controllers.append(controller)
|
||||
|
|
@ -75,6 +76,7 @@ class TriggerManager:
|
|||
entity=provider.declaration,
|
||||
plugin_id=provider.plugin_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -115,26 +117,6 @@ class TriggerManager:
|
|||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
|
||||
|
||||
@classmethod
|
||||
def validate_trigger_credentials(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID, user_id: str, credentials: Mapping[str, str]
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate trigger provider credentials
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider ID
|
||||
:param user_id: User ID
|
||||
:param credentials: Credentials to validate
|
||||
:return: Tuple of (is_valid, error_message)
|
||||
"""
|
||||
try:
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
validation_result = provider.validate_credentials(user_id, credentials)
|
||||
return validation_result.valid, validation_result.message if not validation_result.valid else ""
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def invoke_trigger(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Union
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig, ProviderConfig
|
||||
from core.helper.provider_cache import TriggerProviderCredentialsCache, TriggerProviderOAuthClientParamsCache
|
||||
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
|
|
@ -55,3 +57,24 @@ def create_trigger_provider_oauth_encrypter(
|
|||
cache=cache,
|
||||
)
|
||||
return encrypter, cache
|
||||
|
||||
|
||||
def masked_credentials(
|
||||
schemas: list[ProviderConfig],
|
||||
credentials: Mapping[str, str],
|
||||
) -> Mapping[str, str]:
|
||||
masked_credentials = {}
|
||||
configs = {x.name: x.to_basic_provider_config() for x in schemas}
|
||||
for key, value in credentials.items():
|
||||
config = configs.get(key)
|
||||
if not config:
|
||||
masked_credentials[key] = value
|
||||
continue
|
||||
if config.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if len(value) <= 4:
|
||||
masked_credentials[key] = "*" * len(value)
|
||||
else:
|
||||
masked_credentials[key] = value[:2] + "*" * (len(value) - 4) + value[-2:]
|
||||
else:
|
||||
masked_credentials[key] = value
|
||||
return masked_credentials
|
||||
|
|
|
|||
|
|
@ -2,4 +2,4 @@ from configs import dify_config
|
|||
|
||||
|
||||
def parse_endpoint_id(endpoint_id: str) -> str:
|
||||
return f"{dify_config.CONSOLE_API_URL}/console/api/trigger/endpoint/{endpoint_id}"
|
||||
return f"{dify_config.CONSOLE_API_URL}/triggers/plugin/{endpoint_id}"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
"""plugin_trigger_workflow
|
||||
|
||||
Revision ID: 86f068bf56fb
|
||||
Revises: 132392a2635f
|
||||
Create Date: 2025-09-04 12:12:44.661875
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '86f068bf56fb'
|
||||
down_revision = '132392a2635f'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('subscription_id', sa.String(length=255), nullable=False))
|
||||
batch_op.alter_column('provider_id',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=sa.String(length=512),
|
||||
existing_nullable=False)
|
||||
batch_op.alter_column('trigger_id',
|
||||
existing_type=sa.VARCHAR(length=510),
|
||||
type_=sa.String(length=255),
|
||||
existing_nullable=False)
|
||||
batch_op.drop_constraint(batch_op.f('uniq_plugin_node'), type_='unique')
|
||||
batch_op.drop_constraint(batch_op.f('uniq_trigger_node'), type_='unique')
|
||||
batch_op.drop_index(batch_op.f('workflow_plugin_trigger_tenant_idx'))
|
||||
batch_op.drop_index(batch_op.f('workflow_plugin_trigger_trigger_idx'))
|
||||
batch_op.create_unique_constraint('uniq_app_node_subscription', ['app_id', 'node_id'])
|
||||
batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id'], unique=False)
|
||||
batch_op.drop_column('triggered_by')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('triggered_by', sa.VARCHAR(length=16), autoincrement=False, nullable=False))
|
||||
batch_op.drop_index('workflow_plugin_trigger_tenant_subscription_idx')
|
||||
batch_op.drop_constraint('uniq_app_node_subscription', type_='unique')
|
||||
batch_op.create_index(batch_op.f('workflow_plugin_trigger_trigger_idx'), ['trigger_id'], unique=False)
|
||||
batch_op.create_index(batch_op.f('workflow_plugin_trigger_tenant_idx'), ['tenant_id'], unique=False)
|
||||
batch_op.create_unique_constraint(batch_op.f('uniq_trigger_node'), ['trigger_id', 'node_id'], postgresql_nulls_not_distinct=False)
|
||||
batch_op.create_unique_constraint(batch_op.f('uniq_plugin_node'), ['app_id', 'node_id', 'triggered_by'], postgresql_nulls_not_distinct=False)
|
||||
batch_op.alter_column('trigger_id',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.VARCHAR(length=510),
|
||||
existing_nullable=False)
|
||||
batch_op.alter_column('provider_id',
|
||||
existing_type=sa.String(length=512),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
existing_nullable=False)
|
||||
batch_op.drop_column('subscription_id')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1436,8 +1436,8 @@ class WorkflowPluginTrigger(Base):
|
|||
- node_id (varchar) Node ID which node in the workflow
|
||||
- tenant_id (uuid) Workspace ID
|
||||
- provider_id (varchar) Plugin provider ID
|
||||
- trigger_id (varchar) Unique trigger identifier (provider_id + trigger_name)
|
||||
- triggered_by (varchar) Environment: debugger or production
|
||||
- trigger_id (varchar) trigger id (github_issues_trigger)
|
||||
- subscription_id (varchar) Subscription ID
|
||||
- created_at (timestamp) Creation time
|
||||
- updated_at (timestamp) Last update time
|
||||
"""
|
||||
|
|
@ -1445,19 +1445,17 @@ class WorkflowPluginTrigger(Base):
|
|||
__tablename__ = "workflow_plugin_triggers"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_plugin_trigger_pkey"),
|
||||
sa.Index("workflow_plugin_trigger_tenant_idx", "tenant_id"),
|
||||
sa.Index("workflow_plugin_trigger_trigger_idx", "trigger_id"),
|
||||
sa.UniqueConstraint("app_id", "node_id", "triggered_by", name="uniq_plugin_node"),
|
||||
sa.UniqueConstraint("trigger_id", "node_id", name="uniq_trigger_node"),
|
||||
sa.Index("workflow_plugin_trigger_tenant_subscription_idx", "tenant_id", "subscription_id"),
|
||||
sa.UniqueConstraint("app_id", "node_id", name="uniq_app_node_subscription"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
node_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
trigger_id: Mapped[str] = mapped_column(String(510), nullable=False) # provider_id + trigger_name
|
||||
triggered_by: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
provider_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
trigger_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
subscription_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class TriggerProviderService:
|
|||
# Check provider count limit
|
||||
provider_count = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
||||
.count()
|
||||
)
|
||||
|
||||
|
|
@ -118,7 +118,7 @@ class TriggerProviderService:
|
|||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerSubscription)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
|
|
@ -136,7 +136,7 @@ class TriggerProviderService:
|
|||
user_id=user_id,
|
||||
name=name,
|
||||
endpoint_id=endpoint_id,
|
||||
provider_id=provider_id,
|
||||
provider_id=str(provider_id),
|
||||
parameters=parameters,
|
||||
properties=properties,
|
||||
credentials=encrypter.encrypt(dict(credentials)),
|
||||
|
|
@ -447,5 +447,5 @@ class TriggerProviderService:
|
|||
Get a trigger subscription by the endpoint ID.
|
||||
"""
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
subscription = session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first()
|
||||
subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first()
|
||||
return subscription
|
||||
|
|
|
|||
|
|
@ -2,16 +2,22 @@ import json
|
|||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity
|
||||
from core.trigger.entities.entities import (
|
||||
RequestLog,
|
||||
SubscriptionBuilder,
|
||||
)
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import masked_credentials
|
||||
from core.trigger.utils.endpoint import parse_endpoint_id
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
|
|
@ -43,7 +49,7 @@ class TriggerSubscriptionBuilderService:
|
|||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
) -> None:
|
||||
) -> Mapping[str, Any]:
|
||||
"""Verify a trigger subscription builder"""
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
|
|
@ -53,7 +59,17 @@ class TriggerSubscriptionBuilderService:
|
|||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
|
||||
provider_controller.validate_credentials(user_id, subscription_builder.credentials)
|
||||
if subscription_builder.credential_type == CredentialType.OAUTH2:
|
||||
return {"verified": bool(subscription_builder.credentials)}
|
||||
|
||||
if subscription_builder.credential_type == CredentialType.API_KEY:
|
||||
try:
|
||||
provider_controller.validate_credentials(user_id, subscription_builder.credentials)
|
||||
return {"verified": True}
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
raise ValueError(f"Invalid credentials: {e}")
|
||||
|
||||
return {"verified": True}
|
||||
|
||||
@classmethod
|
||||
def build_trigger_subscription_builder(
|
||||
|
|
@ -72,7 +88,7 @@ class TriggerSubscriptionBuilderService:
|
|||
if not subscription_builder:
|
||||
raise ValueError(f"Subscription builder {subscription_builder_id} not found")
|
||||
|
||||
if subscription_builder.name is None:
|
||||
if not subscription_builder.name:
|
||||
raise ValueError("Subscription builder name is required")
|
||||
|
||||
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value)
|
||||
|
|
@ -97,7 +113,7 @@ class TriggerSubscriptionBuilderService:
|
|||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
endpoint=subscription_builder.endpoint_id,
|
||||
endpoint=parse_endpoint_id(subscription_builder.endpoint_id),
|
||||
parameters=subscription_builder.parameters,
|
||||
credentials=subscription_builder.credentials,
|
||||
)
|
||||
|
|
@ -162,21 +178,57 @@ class TriggerSubscriptionBuilderService:
|
|||
def update_trigger_subscription_builder(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_builder: SubscriptionBuilder,
|
||||
) -> SubscriptionBuilder:
|
||||
provider_id: TriggerProviderID,
|
||||
subscription_builder_id: str,
|
||||
name: str | None,
|
||||
parameters: Mapping[str, Any] | None,
|
||||
properties: Mapping[str, Any] | None,
|
||||
credentials: Mapping[str, str] | None,
|
||||
) -> SubscriptionBuilderApiEntity:
|
||||
"""
|
||||
Update a trigger subscription validation.
|
||||
"""
|
||||
subscription_id = subscription_builder.id
|
||||
subscription_id = subscription_builder_id
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
if not provider_controller:
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
cache_key = cls.encode_cache_key(subscription_id)
|
||||
subscription_builder_cache = cls.get_subscription_builder(subscription_id)
|
||||
if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} not found")
|
||||
subscription_builder = cls.get_subscription_builder(subscription_id)
|
||||
if not subscription_builder or subscription_builder.tenant_id != tenant_id:
|
||||
raise ValueError(f"Subscription {subscription_id} expired or not found")
|
||||
|
||||
if name:
|
||||
subscription_builder.name = name
|
||||
if parameters:
|
||||
subscription_builder.parameters = parameters
|
||||
if properties:
|
||||
subscription_builder.properties = properties
|
||||
if credentials:
|
||||
subscription_builder.credentials = credentials
|
||||
|
||||
redis_client.setex(
|
||||
cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json()
|
||||
)
|
||||
return subscription_builder
|
||||
return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder)
|
||||
|
||||
@classmethod
|
||||
def builder_to_api_entity(
|
||||
cls, controller: PluginTriggerProviderController, entity: SubscriptionBuilder
|
||||
) -> SubscriptionBuilderApiEntity:
|
||||
return SubscriptionBuilderApiEntity(
|
||||
id=entity.id,
|
||||
name=entity.name or "",
|
||||
provider=entity.provider_id,
|
||||
endpoint=parse_endpoint_id(entity.endpoint_id),
|
||||
parameters=entity.parameters,
|
||||
properties=entity.properties,
|
||||
credential_type=CredentialType(entity.credential_type),
|
||||
credentials=masked_credentials(
|
||||
schemas=controller.get_credentials_schema(CredentialType(entity.credential_type)),
|
||||
credentials=entity.credentials,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def delete_trigger_subscription_builder(cls, subscription_id: str) -> None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,376 @@
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowPluginTrigger
|
||||
|
||||
|
||||
class WorkflowPluginTriggerService:
|
||||
"""Service for managing workflow plugin triggers"""
|
||||
|
||||
@classmethod
|
||||
def create_plugin_trigger(
|
||||
cls,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
provider_id: str,
|
||||
trigger_name: str,
|
||||
subscription_id: str,
|
||||
) -> WorkflowPluginTrigger:
|
||||
"""Create a new plugin trigger
|
||||
|
||||
Args:
|
||||
app_id: The app ID
|
||||
tenant_id: The tenant ID
|
||||
node_id: The node ID in the workflow
|
||||
provider_id: The plugin provider ID
|
||||
trigger_name: The trigger name
|
||||
subscription_id: The subscription ID
|
||||
|
||||
Returns:
|
||||
The created WorkflowPluginTrigger instance
|
||||
|
||||
Raises:
|
||||
BadRequest: If plugin trigger already exists for this app and node
|
||||
"""
|
||||
# Create trigger_id from provider_id and trigger_name
|
||||
trigger_id = f"{provider_id}:{trigger_name}"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Check if plugin trigger already exists for this app and node
|
||||
# Based on unique constraint: uniq_app_node
|
||||
existing_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_id,
|
||||
WorkflowPluginTrigger.node_id == node_id,
|
||||
)
|
||||
)
|
||||
|
||||
if existing_trigger:
|
||||
raise BadRequest("Plugin trigger already exists for this app and node")
|
||||
|
||||
# Create new plugin trigger
|
||||
plugin_trigger = WorkflowPluginTrigger(
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
trigger_id=trigger_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
|
||||
session.add(plugin_trigger)
|
||||
session.commit()
|
||||
session.refresh(plugin_trigger)
|
||||
|
||||
return plugin_trigger
|
||||
|
||||
@classmethod
|
||||
def get_plugin_trigger(
|
||||
cls,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
) -> WorkflowPluginTrigger:
|
||||
"""Get a plugin trigger by app_id and node_id
|
||||
|
||||
Args:
|
||||
app_id: The app ID
|
||||
node_id: The node ID in the workflow
|
||||
|
||||
Returns:
|
||||
The WorkflowPluginTrigger instance
|
||||
|
||||
Raises:
|
||||
NotFound: If plugin trigger not found
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger using unique constraint
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_id,
|
||||
WorkflowPluginTrigger.node_id == node_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
|
||||
return plugin_trigger
|
||||
|
||||
@classmethod
|
||||
def get_plugin_trigger_by_subscription(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
) -> WorkflowPluginTrigger:
|
||||
"""Get a plugin trigger by tenant_id and subscription_id
|
||||
This is the primary query pattern, optimized with composite index
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
subscription_id: The subscription ID
|
||||
|
||||
Returns:
|
||||
The WorkflowPluginTrigger instance
|
||||
|
||||
Raises:
|
||||
NotFound: If plugin trigger not found
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger using indexed columns
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.tenant_id == tenant_id,
|
||||
WorkflowPluginTrigger.subscription_id == subscription_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
|
||||
return plugin_trigger
|
||||
|
||||
@classmethod
|
||||
def list_plugin_triggers_by_tenant(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
) -> list[WorkflowPluginTrigger]:
|
||||
"""List all plugin triggers for a tenant
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
|
||||
Returns:
|
||||
List of WorkflowPluginTrigger instances
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
plugin_triggers = session.scalars(
|
||||
select(WorkflowPluginTrigger)
|
||||
.where(WorkflowPluginTrigger.tenant_id == tenant_id)
|
||||
.order_by(WorkflowPluginTrigger.created_at.desc())
|
||||
).all()
|
||||
|
||||
return list(plugin_triggers)
|
||||
|
||||
@classmethod
|
||||
def list_plugin_triggers_by_subscription(
|
||||
cls,
|
||||
subscription_id: str,
|
||||
) -> list[WorkflowPluginTrigger]:
|
||||
"""List all plugin triggers for a subscription
|
||||
|
||||
Args:
|
||||
subscription_id: The subscription ID
|
||||
|
||||
Returns:
|
||||
List of WorkflowPluginTrigger instances
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
plugin_triggers = session.scalars(
|
||||
select(WorkflowPluginTrigger)
|
||||
.where(WorkflowPluginTrigger.subscription_id == subscription_id)
|
||||
.order_by(WorkflowPluginTrigger.created_at.desc())
|
||||
).all()
|
||||
|
||||
return list(plugin_triggers)
|
||||
|
||||
@classmethod
|
||||
def update_plugin_trigger(
|
||||
cls,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
provider_id: Optional[str] = None,
|
||||
trigger_name: Optional[str] = None,
|
||||
subscription_id: Optional[str] = None,
|
||||
) -> WorkflowPluginTrigger:
|
||||
"""Update a plugin trigger
|
||||
|
||||
Args:
|
||||
app_id: The app ID
|
||||
node_id: The node ID in the workflow
|
||||
provider_id: The new provider ID (optional)
|
||||
trigger_name: The new trigger name (optional)
|
||||
subscription_id: The new subscription ID (optional)
|
||||
|
||||
Returns:
|
||||
The updated WorkflowPluginTrigger instance
|
||||
|
||||
Raises:
|
||||
NotFound: If plugin trigger not found
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger using unique constraint
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_id,
|
||||
WorkflowPluginTrigger.node_id == node_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
|
||||
# Update fields if provided
|
||||
if provider_id:
|
||||
plugin_trigger.provider_id = provider_id
|
||||
|
||||
if trigger_name:
|
||||
# Update trigger_id if provider_id or trigger_name changed
|
||||
provider_id = provider_id or plugin_trigger.provider_id
|
||||
plugin_trigger.trigger_id = f"{provider_id}:{trigger_name}"
|
||||
|
||||
if subscription_id:
|
||||
plugin_trigger.subscription_id = subscription_id
|
||||
|
||||
session.commit()
|
||||
session.refresh(plugin_trigger)
|
||||
|
||||
return plugin_trigger
|
||||
|
||||
@classmethod
|
||||
def update_plugin_trigger_by_subscription(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
provider_id: Optional[str] = None,
|
||||
trigger_name: Optional[str] = None,
|
||||
new_subscription_id: Optional[str] = None,
|
||||
) -> WorkflowPluginTrigger:
|
||||
"""Update a plugin trigger by tenant_id and subscription_id
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
subscription_id: The current subscription ID
|
||||
provider_id: The new provider ID (optional)
|
||||
trigger_name: The new trigger name (optional)
|
||||
new_subscription_id: The new subscription ID (optional)
|
||||
|
||||
Returns:
|
||||
The updated WorkflowPluginTrigger instance
|
||||
|
||||
Raises:
|
||||
NotFound: If plugin trigger not found
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger using indexed columns
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.tenant_id == tenant_id,
|
||||
WorkflowPluginTrigger.subscription_id == subscription_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
|
||||
# Update fields if provided
|
||||
if provider_id:
|
||||
plugin_trigger.provider_id = provider_id
|
||||
|
||||
if trigger_name:
|
||||
# Update trigger_id if provider_id or trigger_name changed
|
||||
provider_id = provider_id or plugin_trigger.provider_id
|
||||
plugin_trigger.trigger_id = f"{provider_id}:{trigger_name}"
|
||||
|
||||
if new_subscription_id:
|
||||
plugin_trigger.subscription_id = new_subscription_id
|
||||
|
||||
session.commit()
|
||||
session.refresh(plugin_trigger)
|
||||
|
||||
return plugin_trigger
|
||||
|
||||
@classmethod
|
||||
def delete_plugin_trigger(
|
||||
cls,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
) -> None:
|
||||
"""Delete a plugin trigger by app_id and node_id
|
||||
|
||||
Args:
|
||||
app_id: The app ID
|
||||
node_id: The node ID in the workflow
|
||||
|
||||
Raises:
|
||||
NotFound: If plugin trigger not found
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger using unique constraint
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.app_id == app_id,
|
||||
WorkflowPluginTrigger.node_id == node_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
|
||||
session.delete(plugin_trigger)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def delete_plugin_trigger_by_subscription(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
) -> None:
|
||||
"""Delete a plugin trigger by tenant_id and subscription_id
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
subscription_id: The subscription ID
|
||||
|
||||
Raises:
|
||||
NotFound: If plugin trigger not found
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Find plugin trigger using indexed columns
|
||||
plugin_trigger = session.scalar(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.tenant_id == tenant_id,
|
||||
WorkflowPluginTrigger.subscription_id == subscription_id,
|
||||
)
|
||||
)
|
||||
|
||||
if not plugin_trigger:
|
||||
raise NotFound("Plugin trigger not found")
|
||||
|
||||
session.delete(plugin_trigger)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def delete_all_by_subscription(
|
||||
cls,
|
||||
subscription_id: str,
|
||||
) -> int:
|
||||
"""Delete all plugin triggers for a subscription
|
||||
Useful when a subscription is cancelled
|
||||
|
||||
Args:
|
||||
subscription_id: The subscription ID
|
||||
|
||||
Returns:
|
||||
Number of triggers deleted
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Find all plugin triggers for this subscription
|
||||
plugin_triggers = session.scalars(
|
||||
select(WorkflowPluginTrigger).where(
|
||||
WorkflowPluginTrigger.subscription_id == subscription_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
count = len(plugin_triggers)
|
||||
|
||||
for trigger in plugin_triggers:
|
||||
session.delete(trigger)
|
||||
|
||||
session.commit()
|
||||
|
||||
return count
|
||||
Loading…
Reference in New Issue