From a62d7aa3eeffca5080a441da3a3458c61b81d9cf Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 4 Sep 2025 12:47:51 +0800 Subject: [PATCH] feat(trigger): add plugin trigger workflow support and refactor trigger system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add new workflow plugin trigger service for managing plugin-based triggers - Implement trigger provider encryption utilities for secure credential storage - Add custom trigger errors module for better error handling - Refactor trigger provider and manager classes for improved plugin integration - Update API endpoints to support plugin trigger workflows - Add database migration for plugin trigger workflow support 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../console/app/workflow_trigger.py | 160 ++------ .../console/workspace/trigger_providers.py | 60 ++- api/controllers/trigger/trigger.py | 4 +- api/core/plugin/entities/request.py | 8 +- api/core/plugin/impl/trigger.py | 17 +- api/core/trigger/entities/api_entities.py | 11 + api/core/trigger/errors.py | 2 + api/core/trigger/provider.py | 25 +- api/core/trigger/trigger_manager.py | 22 +- api/core/trigger/utils/encryption.py | 23 ++ api/core/trigger/utils/endpoint.py | 2 +- ...12-86f068bf56fb_plugin_trigger_workflow.py | 62 +++ api/models/workflow.py | 16 +- .../trigger/trigger_provider_service.py | 8 +- .../trigger_subscription_builder_service.py | 74 +++- .../workflow_plugin_trigger_service.py | 376 ++++++++++++++++++ 16 files changed, 666 insertions(+), 204 deletions(-) create mode 100644 api/core/trigger/errors.py create mode 100644 api/migrations/versions/2025_09_04_1212-86f068bf56fb_plugin_trigger_workflow.py create mode 100644 api/services/workflow_plugin_trigger_service.py diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index 30c8a8a8ab..58cbf51ea8 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -18,7 +18,7 @@ from models.workflow import AppTrigger, AppTriggerStatus, WorkflowWebhookTrigger logger = logging.getLogger(__name__) -from models.workflow import WorkflowPluginTrigger +from services.workflow_plugin_trigger_service import WorkflowPluginTriggerService class PluginTriggerApi(Resource): @@ -34,54 +34,21 @@ class PluginTriggerApi(Resource): parser.add_argument("node_id", type=str, required=True, help="Node ID is required") parser.add_argument("provider_id", type=str, required=True, help="Provider ID is required") parser.add_argument("trigger_name", type=str, required=True, help="Trigger name is required") - parser.add_argument( - "triggered_by", - type=str, - required=False, - default="production", - choices=["debugger", "production"], - help="triggered_by must be debugger or production", - ) + parser.add_argument("subscription_id", type=str, required=True, help="Subscription ID is required") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - node_id = args["node_id"] - provider_id = args["provider_id"] - trigger_name = args["trigger_name"] - triggered_by = args["triggered_by"] - - # Create trigger_id from provider_id and trigger_name - trigger_id = f"{provider_id}:{trigger_name}" - - with Session(db.engine) as session: - # Check if plugin trigger already exists for this app, node, and environment - existing_trigger = session.scalar( - select(WorkflowPluginTrigger).where( - WorkflowPluginTrigger.app_id == app_model.id, - WorkflowPluginTrigger.node_id == node_id, - WorkflowPluginTrigger.triggered_by == triggered_by, - ) - ) - - if existing_trigger: - raise BadRequest("Plugin trigger already exists for this node and environment") - - # Create new plugin trigger - plugin_trigger = WorkflowPluginTrigger( - app_id=app_model.id, - node_id=node_id, - tenant_id=current_user.current_tenant_id, - provider_id=provider_id, - trigger_id=trigger_id, - triggered_by=triggered_by, - ) - - session.add(plugin_trigger) - session.commit() - session.refresh(plugin_trigger) + plugin_trigger = WorkflowPluginTriggerService.create_plugin_trigger( + app_id=app_model.id, + tenant_id=current_user.current_tenant_id, + node_id=args["node_id"], + provider_id=args["provider_id"], + trigger_name=args["trigger_name"], + subscription_id=args["subscription_id"], + ) return plugin_trigger @@ -93,33 +60,14 @@ class PluginTriggerApi(Resource): """Get plugin trigger""" parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, help="Node ID is required") - parser.add_argument( - "triggered_by", - type=str, - required=False, - default="production", - choices=["debugger", "production"], - help="triggered_by must be debugger or production", - ) args = parser.parse_args() - node_id = args["node_id"] - triggered_by = args["triggered_by"] + plugin_trigger = WorkflowPluginTriggerService.get_plugin_trigger( + app_id=app_model.id, + node_id=args["node_id"], + ) - with Session(db.engine) as session: - # Find plugin trigger - plugin_trigger = session.scalar( - select(WorkflowPluginTrigger).where( - WorkflowPluginTrigger.app_id == app_model.id, - WorkflowPluginTrigger.node_id == node_id, - WorkflowPluginTrigger.triggered_by == triggered_by, - WorkflowPluginTrigger.tenant_id == current_user.current_tenant_id, - ) - ) - - if not plugin_trigger: - raise NotFound("Plugin trigger not found") - return plugin_trigger + return plugin_trigger @setup_required @login_required @@ -131,51 +79,22 @@ class PluginTriggerApi(Resource): parser.add_argument("node_id", type=str, required=True, help="Node ID is required") parser.add_argument("provider_id", type=str, required=False, help="Provider ID") parser.add_argument("trigger_name", type=str, required=False, help="Trigger name") - parser.add_argument( - "triggered_by", - type=str, - required=False, - default="production", - choices=["debugger", "production"], - help="triggered_by must be debugger or production", - ) + parser.add_argument("subscription_id", type=str, required=False, help="Subscription ID") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - node_id = args["node_id"] - triggered_by = args["triggered_by"] + plugin_trigger = WorkflowPluginTriggerService.update_plugin_trigger( + app_id=app_model.id, + node_id=args["node_id"], + provider_id=args.get("provider_id"), + trigger_name=args.get("trigger_name"), + subscription_id=args.get("subscription_id"), + ) - with Session(db.engine) as session: - # Find plugin trigger - plugin_trigger = session.scalar( - select(WorkflowPluginTrigger).where( - WorkflowPluginTrigger.app_id == app_model.id, - WorkflowPluginTrigger.node_id == node_id, - WorkflowPluginTrigger.triggered_by == triggered_by, - WorkflowPluginTrigger.tenant_id == current_user.current_tenant_id, - ) - ) - - if not plugin_trigger: - raise NotFound("Plugin trigger not found") - - # Update fields if provided - if args.get("provider_id"): - plugin_trigger.provider_id = args["provider_id"] - - if args.get("trigger_name"): - # Update trigger_id if provider_id or trigger_name changed - provider_id = args.get("provider_id") or plugin_trigger.provider_id - trigger_name = args["trigger_name"] - plugin_trigger.trigger_id = f"{provider_id}:{trigger_name}" - - session.commit() - session.refresh(plugin_trigger) - - return plugin_trigger + return plugin_trigger @setup_required @login_required @@ -185,39 +104,16 @@ class PluginTriggerApi(Resource): """Delete plugin trigger""" parser = reqparse.RequestParser() parser.add_argument("node_id", type=str, required=True, help="Node ID is required") - parser.add_argument( - "triggered_by", - type=str, - required=False, - default="production", - choices=["debugger", "production"], - help="triggered_by must be debugger or production", - ) args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - node_id = args["node_id"] - triggered_by = args["triggered_by"] - - with Session(db.engine) as session: - # Find plugin trigger - plugin_trigger = session.scalar( - select(WorkflowPluginTrigger).where( - WorkflowPluginTrigger.app_id == app_model.id, - WorkflowPluginTrigger.node_id == node_id, - WorkflowPluginTrigger.triggered_by == triggered_by, - WorkflowPluginTrigger.tenant_id == current_user.current_tenant_id, - ) - ) - - if not plugin_trigger: - raise NotFound("Plugin trigger not found") - - session.delete(plugin_trigger) - session.commit() + WorkflowPluginTriggerService.delete_plugin_trigger( + app_id=app_model.id, + node_id=args["node_id"], + ) return {"result": "success"}, 204 diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 53096ff867..34e866d408 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -117,6 +117,43 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): raise +class TriggerSubscriptionBuilderUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider, subscription_builder_id): + """Update a subscription instance for a trigger provider""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + + parser = reqparse.RequestParser() + # The name of the subscription builder + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + # The parameters of the subscription builder + parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json") + # The properties of the subscription builder + parser.add_argument("properties", type=dict, required=False, nullable=True, location="json") + # The credentials of the subscription builder + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + args = parser.parse_args() + try: + return jsonable_encoder( + TriggerSubscriptionBuilderService.update_trigger_subscription_builder( + tenant_id=user.current_tenant_id, + provider_id=TriggerProviderID(provider), + subscription_builder_id=subscription_builder_id, + name=args.get("name", None), + parameters=args.get("parameters", None), + properties=args.get("properties", None), + credentials=args.get("credentials", None), + ) + ) + except Exception as e: + logger.exception("Error updating provider credential", exc_info=e) + raise + + class TriggerSubscriptionBuilderBuildApi(Resource): @setup_required @login_required @@ -216,9 +253,26 @@ class TriggerOAuthAuthorizeApi(Resource): redirect_uri=redirect_uri, system_credentials=oauth_client_params, ) + # Create subscription builder + subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( + tenant_id=tenant_id, + user_id=user.id, + provider_id=provider_id, + credentials={}, + credential_type=CredentialType.OAUTH2, + credential_expires_at=0, + expires_at=0, + ) # Create response with cookie - response = make_response(jsonable_encoder(authorization_url_response)) + response = make_response( + jsonable_encoder( + { + "authorization_url": authorization_url_response, + "subscription_builder": subscription_builder, + } + ) + ) response.set_cookie( "context_id", context_id, @@ -410,6 +464,10 @@ api.add_resource( TriggerSubscriptionBuilderCreateApi, "/workspaces/current/trigger-provider//subscriptions/builder/create", ) +api.add_resource( + TriggerSubscriptionBuilderUpdateApi, + "/workspaces/current/trigger-provider//subscriptions/builder/update/", +) api.add_resource( TriggerSubscriptionBuilderVerifyApi, "/workspaces/current/trigger-provider//subscriptions/builder/verify/", diff --git a/api/controllers/trigger/trigger.py b/api/controllers/trigger/trigger.py index 29ccd87812..b7bcfffcf6 100644 --- a/api/controllers/trigger/trigger.py +++ b/api/controllers/trigger/trigger.py @@ -14,9 +14,7 @@ UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f UUID_MATCHER = re.compile(UUID_PATTERN) -@bp.route( - "/trigger/endpoint/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"] -) +@bp.route("/plugin/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]) def trigger_endpoint(endpoint_id: str): """ Handle endpoint trigger calls. diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index da0f2b8419..4c271912f1 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -254,11 +254,13 @@ class TriggerSubscriptionResponse(BaseModel): class TriggerValidateProviderCredentialsResponse(BaseModel): - valid: bool - message: str - error: str + result: bool class TriggerDispatchResponse: triggers: list[str] response: Response + + def __init__(self, triggers: list[str], response: Response): + self.triggers = triggers + self.response = response diff --git a/api/core/plugin/impl/trigger.py b/api/core/plugin/impl/trigger.py index ee643c9d76..d9bc62359f 100644 --- a/api/core/plugin/impl/trigger.py +++ b/api/core/plugin/impl/trigger.py @@ -42,11 +42,10 @@ class PluginTriggerManager(BasePluginClient): ) for provider in response: - provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" - + provider.declaration.identity.name = str(provider.provider) # override the provider name for each trigger to plugin_id/provider_name for trigger in provider.declaration.triggers: - trigger.identity.provider = provider.declaration.identity.name + trigger.identity.provider = str(provider.provider) return response @@ -59,7 +58,7 @@ class PluginTriggerManager(BasePluginClient): data = json_response.get("data") if data: for trigger in data.get("declaration", {}).get("triggers", []): - trigger["identity"]["provider"] = provider_id.provider_name + trigger["identity"]["provider"] = str(provider_id) return json_response @@ -71,11 +70,11 @@ class PluginTriggerManager(BasePluginClient): transformer=transformer, ) - response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" + response.declaration.identity.name = str(provider_id) # override the provider name for each trigger to plugin_id/provider_name for trigger in response.declaration.triggers: - trigger.identity.provider = response.declaration.identity.name + trigger.identity.provider = str(provider_id) return response @@ -123,7 +122,7 @@ class PluginTriggerManager(BasePluginClient): def validate_provider_credentials( self, tenant_id: str, user_id: str, provider: str, credentials: Mapping[str, str] - ) -> TriggerValidateProviderCredentialsResponse: + ) -> bool: """ Validate the credentials of the trigger provider. """ @@ -147,9 +146,9 @@ class PluginTriggerManager(BasePluginClient): ) for resp in response: - return resp + return resp.result - return TriggerValidateProviderCredentialsResponse(valid=False, message="No response", error="No response") + raise ValueError("No response received from plugin daemon for validate provider credentials") def dispatch_event( self, diff --git a/api/core/trigger/entities/api_entities.py b/api/core/trigger/entities/api_entities.py index 0820044d88..641eaa4c7a 100644 --- a/api/core/trigger/entities/api_entities.py +++ b/api/core/trigger/entities/api_entities.py @@ -43,4 +43,15 @@ class TriggerApiEntity(BaseModel): output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger") +class SubscriptionBuilderApiEntity(BaseModel): + id: str = Field(description="The id of the subscription builder") + name: str = Field(description="The name of the subscription builder") + provider: str = Field(description="The provider id of the subscription builder") + endpoint: str = Field(description="The endpoint id of the subscription builder") + parameters: Mapping[str, Any] = Field(description="The parameters of the subscription builder") + properties: Mapping[str, Any] = Field(description="The properties of the subscription builder") + credentials: Mapping[str, str] = Field(description="The credentials of the subscription builder") + credential_type: CredentialType = Field(description="The credential type of the subscription builder") + + __all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderSubscriptionApiEntity"] diff --git a/api/core/trigger/errors.py b/api/core/trigger/errors.py new file mode 100644 index 0000000000..bbc27e1eae --- /dev/null +++ b/api/core/trigger/errors.py @@ -0,0 +1,2 @@ +class TriggerProviderCredentialValidationError(ValueError): + pass diff --git a/api/core/trigger/provider.py b/api/core/trigger/provider.py index 6dae530184..703e9b0d19 100644 --- a/api/core/trigger/provider.py +++ b/api/core/trigger/provider.py @@ -14,7 +14,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import ( TriggerDispatchResponse, TriggerInvokeResponse, - TriggerValidateProviderCredentialsResponse, ) from core.plugin.impl.trigger import PluginTriggerManager from core.trigger.entities.api_entities import TriggerProviderApiEntity @@ -27,6 +26,7 @@ from core.trigger.entities.entities import ( TriggerProviderIdentity, Unsubscription, ) +from core.trigger.errors import TriggerProviderCredentialValidationError logger = logging.getLogger(__name__) @@ -41,6 +41,7 @@ class PluginTriggerProviderController: entity: TriggerProviderEntity, plugin_id: str, plugin_unique_identifier: str, + provider_id: TriggerProviderID, tenant_id: str, ): """ @@ -49,18 +50,20 @@ class PluginTriggerProviderController: :param entity: Trigger provider entity :param plugin_id: Plugin ID :param plugin_unique_identifier: Plugin unique identifier + :param provider_id: Provider ID :param tenant_id: Tenant ID """ self.entity = entity self.tenant_id = tenant_id self.plugin_id = plugin_id + self.provider_id = provider_id self.plugin_unique_identifier = plugin_unique_identifier def get_provider_id(self) -> TriggerProviderID: """ Get provider ID """ - return TriggerProviderID(f"{self.plugin_id}/{self.entity.identity.name}") + return self.provider_id def to_api_entity(self) -> TriggerProviderApiEntity: """ @@ -101,9 +104,7 @@ class PluginTriggerProviderController: """ return self.entity.subscription_schema - def validate_credentials( - self, user_id: str, credentials: Mapping[str, str] - ) -> TriggerValidateProviderCredentialsResponse: + def validate_credentials(self, user_id: str, credentials: Mapping[str, str]) -> None: """ Validate credentials against schema @@ -113,21 +114,21 @@ class PluginTriggerProviderController: # First validate against schema for config in self.entity.credentials_schema: if config.required and config.name not in credentials: - return TriggerValidateProviderCredentialsResponse( - valid=False, - message=f"Missing required credential field: {config.name}", - error=f"Missing required credential field: {config.name}", - ) + raise TriggerProviderCredentialValidationError(f"Missing required credential field: {config.name}") # Then validate with the plugin daemon manager = PluginTriggerManager() provider_id = self.get_provider_id() - return manager.validate_provider_credentials( + response = manager.validate_provider_credentials( tenant_id=self.tenant_id, user_id=user_id, provider=str(provider_id), credentials=credentials, ) + if not response: + raise TriggerProviderCredentialValidationError( + "Invalid credentials", + ) def get_supported_credential_types(self) -> list[CredentialType]: """ @@ -154,6 +155,8 @@ class PluginTriggerProviderController: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] if credential_type == CredentialType.API_KEY: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + if credential_type == CredentialType.UNAUTHORIZED: + return [] raise ValueError(f"Invalid credential type: {credential_type}") def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]: diff --git a/api/core/trigger/trigger_manager.py b/api/core/trigger/trigger_manager.py index b20b1fdd79..719dd6f251 100644 --- a/api/core/trigger/trigger_manager.py +++ b/api/core/trigger/trigger_manager.py @@ -46,6 +46,7 @@ class TriggerManager: entity=provider.declaration, plugin_id=provider.plugin_id, plugin_unique_identifier=provider.plugin_unique_identifier, + provider_id=TriggerProviderID(provider.provider), tenant_id=tenant_id, ) controllers.append(controller) @@ -75,6 +76,7 @@ class TriggerManager: entity=provider.declaration, plugin_id=provider.plugin_id, plugin_unique_identifier=provider.plugin_unique_identifier, + provider_id=provider_id, tenant_id=tenant_id, ) except Exception as e: @@ -115,26 +117,6 @@ class TriggerManager: """ return cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name) - @classmethod - def validate_trigger_credentials( - cls, tenant_id: str, provider_id: TriggerProviderID, user_id: str, credentials: Mapping[str, str] - ) -> tuple[bool, str]: - """ - Validate trigger provider credentials - - :param tenant_id: Tenant ID - :param provider_id: Provider ID - :param user_id: User ID - :param credentials: Credentials to validate - :return: Tuple of (is_valid, error_message) - """ - try: - provider = cls.get_trigger_provider(tenant_id, provider_id) - validation_result = provider.validate_credentials(user_id, credentials) - return validation_result.valid, validation_result.message if not validation_result.valid else "" - except Exception as e: - return False, str(e) - @classmethod def invoke_trigger( cls, diff --git a/api/core/trigger/utils/encryption.py b/api/core/trigger/utils/encryption.py index 663807ff7b..0f49343b82 100644 --- a/api/core/trigger/utils/encryption.py +++ b/api/core/trigger/utils/encryption.py @@ -1,5 +1,7 @@ +from collections.abc import Mapping from typing import Union +from core.entities.provider_entities import BasicProviderConfig, ProviderConfig from core.helper.provider_cache import TriggerProviderCredentialsCache, TriggerProviderOAuthClientParamsCache from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType @@ -55,3 +57,24 @@ def create_trigger_provider_oauth_encrypter( cache=cache, ) return encrypter, cache + + +def masked_credentials( + schemas: list[ProviderConfig], + credentials: Mapping[str, str], +) -> Mapping[str, str]: + masked_credentials = {} + configs = {x.name: x.to_basic_provider_config() for x in schemas} + for key, value in credentials.items(): + config = configs.get(key) + if not config: + masked_credentials[key] = value + continue + if config.type == BasicProviderConfig.Type.SECRET_INPUT: + if len(value) <= 4: + masked_credentials[key] = "*" * len(value) + else: + masked_credentials[key] = value[:2] + "*" * (len(value) - 4) + value[-2:] + else: + masked_credentials[key] = value + return masked_credentials diff --git a/api/core/trigger/utils/endpoint.py b/api/core/trigger/utils/endpoint.py index 242075acac..c203cdd9f3 100644 --- a/api/core/trigger/utils/endpoint.py +++ b/api/core/trigger/utils/endpoint.py @@ -2,4 +2,4 @@ from configs import dify_config def parse_endpoint_id(endpoint_id: str) -> str: - return f"{dify_config.CONSOLE_API_URL}/console/api/trigger/endpoint/{endpoint_id}" + return f"{dify_config.CONSOLE_API_URL}/triggers/plugin/{endpoint_id}" diff --git a/api/migrations/versions/2025_09_04_1212-86f068bf56fb_plugin_trigger_workflow.py b/api/migrations/versions/2025_09_04_1212-86f068bf56fb_plugin_trigger_workflow.py new file mode 100644 index 0000000000..58f6ef07ed --- /dev/null +++ b/api/migrations/versions/2025_09_04_1212-86f068bf56fb_plugin_trigger_workflow.py @@ -0,0 +1,62 @@ +"""plugin_trigger_workflow + +Revision ID: 86f068bf56fb +Revises: 132392a2635f +Create Date: 2025-09-04 12:12:44.661875 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '86f068bf56fb' +down_revision = '132392a2635f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op: + batch_op.add_column(sa.Column('subscription_id', sa.String(length=255), nullable=False)) + batch_op.alter_column('provider_id', + existing_type=sa.VARCHAR(length=255), + type_=sa.String(length=512), + existing_nullable=False) + batch_op.alter_column('trigger_id', + existing_type=sa.VARCHAR(length=510), + type_=sa.String(length=255), + existing_nullable=False) + batch_op.drop_constraint(batch_op.f('uniq_plugin_node'), type_='unique') + batch_op.drop_constraint(batch_op.f('uniq_trigger_node'), type_='unique') + batch_op.drop_index(batch_op.f('workflow_plugin_trigger_tenant_idx')) + batch_op.drop_index(batch_op.f('workflow_plugin_trigger_trigger_idx')) + batch_op.create_unique_constraint('uniq_app_node_subscription', ['app_id', 'node_id']) + batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id'], unique=False) + batch_op.drop_column('triggered_by') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op: + batch_op.add_column(sa.Column('triggered_by', sa.VARCHAR(length=16), autoincrement=False, nullable=False)) + batch_op.drop_index('workflow_plugin_trigger_tenant_subscription_idx') + batch_op.drop_constraint('uniq_app_node_subscription', type_='unique') + batch_op.create_index(batch_op.f('workflow_plugin_trigger_trigger_idx'), ['trigger_id'], unique=False) + batch_op.create_index(batch_op.f('workflow_plugin_trigger_tenant_idx'), ['tenant_id'], unique=False) + batch_op.create_unique_constraint(batch_op.f('uniq_trigger_node'), ['trigger_id', 'node_id'], postgresql_nulls_not_distinct=False) + batch_op.create_unique_constraint(batch_op.f('uniq_plugin_node'), ['app_id', 'node_id', 'triggered_by'], postgresql_nulls_not_distinct=False) + batch_op.alter_column('trigger_id', + existing_type=sa.String(length=255), + type_=sa.VARCHAR(length=510), + existing_nullable=False) + batch_op.alter_column('provider_id', + existing_type=sa.String(length=512), + type_=sa.VARCHAR(length=255), + existing_nullable=False) + batch_op.drop_column('subscription_id') + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index 58ec685dd2..b2edf676dc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1436,8 +1436,8 @@ class WorkflowPluginTrigger(Base): - node_id (varchar) Node ID which node in the workflow - tenant_id (uuid) Workspace ID - provider_id (varchar) Plugin provider ID - - trigger_id (varchar) Unique trigger identifier (provider_id + trigger_name) - - triggered_by (varchar) Environment: debugger or production + - trigger_id (varchar) trigger id (github_issues_trigger) + - subscription_id (varchar) Subscription ID - created_at (timestamp) Creation time - updated_at (timestamp) Last update time """ @@ -1445,19 +1445,17 @@ class WorkflowPluginTrigger(Base): __tablename__ = "workflow_plugin_triggers" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="workflow_plugin_trigger_pkey"), - sa.Index("workflow_plugin_trigger_tenant_idx", "tenant_id"), - sa.Index("workflow_plugin_trigger_trigger_idx", "trigger_id"), - sa.UniqueConstraint("app_id", "node_id", "triggered_by", name="uniq_plugin_node"), - sa.UniqueConstraint("trigger_id", "node_id", name="uniq_trigger_node"), + sa.Index("workflow_plugin_trigger_tenant_subscription_idx", "tenant_id", "subscription_id"), + sa.UniqueConstraint("app_id", "node_id", name="uniq_app_node_subscription"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) node_id: Mapped[str] = mapped_column(String(64), nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_id: Mapped[str] = mapped_column(String(255), nullable=False) - trigger_id: Mapped[str] = mapped_column(String(510), nullable=False) # provider_id + trigger_name - triggered_by: Mapped[str] = mapped_column(String(16), nullable=False) + provider_id: Mapped[str] = mapped_column(String(512), nullable=False) + trigger_id: Mapped[str] = mapped_column(String(255), nullable=False) + subscription_id: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( DateTime, diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 2b1f75c63b..0431956bd5 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -105,7 +105,7 @@ class TriggerProviderService: # Check provider count limit provider_count = ( session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=provider_id) + .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) .count() ) @@ -118,7 +118,7 @@ class TriggerProviderService: # Check if name already exists existing = ( session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name) + .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) .first() ) if existing: @@ -136,7 +136,7 @@ class TriggerProviderService: user_id=user_id, name=name, endpoint_id=endpoint_id, - provider_id=provider_id, + provider_id=str(provider_id), parameters=parameters, properties=properties, credentials=encrypter.encrypt(dict(credentials)), @@ -447,5 +447,5 @@ class TriggerProviderService: Get a trigger subscription by the endpoint ID. """ with Session(db.engine, autoflush=False) as session: - subscription = session.query(TriggerSubscription).filter_by(endpoint=endpoint_id).first() + subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first() return subscription diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index cf7f564d68..745c39370a 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -2,16 +2,22 @@ import json import logging import uuid from collections.abc import Mapping +from typing import Any from flask import Request, Response from core.plugin.entities.plugin import TriggerProviderID from core.plugin.entities.plugin_daemon import CredentialType +from core.tools.errors import ToolProviderCredentialValidationError +from core.trigger.entities.api_entities import SubscriptionBuilderApiEntity from core.trigger.entities.entities import ( RequestLog, SubscriptionBuilder, ) +from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager +from core.trigger.utils.encryption import masked_credentials +from core.trigger.utils.endpoint import parse_endpoint_id from extensions.ext_redis import redis_client from services.trigger.trigger_provider_service import TriggerProviderService @@ -43,7 +49,7 @@ class TriggerSubscriptionBuilderService: user_id: str, provider_id: TriggerProviderID, subscription_builder_id: str, - ) -> None: + ) -> Mapping[str, Any]: """Verify a trigger subscription builder""" provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) if not provider_controller: @@ -53,7 +59,17 @@ class TriggerSubscriptionBuilderService: if not subscription_builder: raise ValueError(f"Subscription builder {subscription_builder_id} not found") - provider_controller.validate_credentials(user_id, subscription_builder.credentials) + if subscription_builder.credential_type == CredentialType.OAUTH2: + return {"verified": bool(subscription_builder.credentials)} + + if subscription_builder.credential_type == CredentialType.API_KEY: + try: + provider_controller.validate_credentials(user_id, subscription_builder.credentials) + return {"verified": True} + except ToolProviderCredentialValidationError as e: + raise ValueError(f"Invalid credentials: {e}") + + return {"verified": True} @classmethod def build_trigger_subscription_builder( @@ -72,7 +88,7 @@ class TriggerSubscriptionBuilderService: if not subscription_builder: raise ValueError(f"Subscription builder {subscription_builder_id} not found") - if subscription_builder.name is None: + if not subscription_builder.name: raise ValueError("Subscription builder name is required") credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value) @@ -97,7 +113,7 @@ class TriggerSubscriptionBuilderService: tenant_id=tenant_id, user_id=user_id, provider_id=provider_id, - endpoint=subscription_builder.endpoint_id, + endpoint=parse_endpoint_id(subscription_builder.endpoint_id), parameters=subscription_builder.parameters, credentials=subscription_builder.credentials, ) @@ -162,21 +178,57 @@ class TriggerSubscriptionBuilderService: def update_trigger_subscription_builder( cls, tenant_id: str, - subscription_builder: SubscriptionBuilder, - ) -> SubscriptionBuilder: + provider_id: TriggerProviderID, + subscription_builder_id: str, + name: str | None, + parameters: Mapping[str, Any] | None, + properties: Mapping[str, Any] | None, + credentials: Mapping[str, str] | None, + ) -> SubscriptionBuilderApiEntity: """ Update a trigger subscription validation. """ - subscription_id = subscription_builder.id + subscription_id = subscription_builder_id + provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) + if not provider_controller: + raise ValueError(f"Provider {provider_id} not found") + cache_key = cls.encode_cache_key(subscription_id) - subscription_builder_cache = cls.get_subscription_builder(subscription_id) - if not subscription_builder_cache or subscription_builder_cache.tenant_id != tenant_id: - raise ValueError(f"Subscription {subscription_id} not found") + subscription_builder = cls.get_subscription_builder(subscription_id) + if not subscription_builder or subscription_builder.tenant_id != tenant_id: + raise ValueError(f"Subscription {subscription_id} expired or not found") + + if name: + subscription_builder.name = name + if parameters: + subscription_builder.parameters = parameters + if properties: + subscription_builder.properties = properties + if credentials: + subscription_builder.credentials = credentials redis_client.setex( cache_key, cls.__VALIDATION_REQUEST_CACHE_EXPIRE_MS__, subscription_builder.model_dump_json() ) - return subscription_builder + return cls.builder_to_api_entity(controller=provider_controller, entity=subscription_builder) + + @classmethod + def builder_to_api_entity( + cls, controller: PluginTriggerProviderController, entity: SubscriptionBuilder + ) -> SubscriptionBuilderApiEntity: + return SubscriptionBuilderApiEntity( + id=entity.id, + name=entity.name or "", + provider=entity.provider_id, + endpoint=parse_endpoint_id(entity.endpoint_id), + parameters=entity.parameters, + properties=entity.properties, + credential_type=CredentialType(entity.credential_type), + credentials=masked_credentials( + schemas=controller.get_credentials_schema(CredentialType(entity.credential_type)), + credentials=entity.credentials, + ), + ) @classmethod def delete_trigger_subscription_builder(cls, subscription_id: str) -> None: diff --git a/api/services/workflow_plugin_trigger_service.py b/api/services/workflow_plugin_trigger_service.py new file mode 100644 index 0000000000..337a79ec8b --- /dev/null +++ b/api/services/workflow_plugin_trigger_service.py @@ -0,0 +1,376 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session +from werkzeug.exceptions import BadRequest, NotFound + +from extensions.ext_database import db +from models.workflow import WorkflowPluginTrigger + + +class WorkflowPluginTriggerService: + """Service for managing workflow plugin triggers""" + + @classmethod + def create_plugin_trigger( + cls, + app_id: str, + tenant_id: str, + node_id: str, + provider_id: str, + trigger_name: str, + subscription_id: str, + ) -> WorkflowPluginTrigger: + """Create a new plugin trigger + + Args: + app_id: The app ID + tenant_id: The tenant ID + node_id: The node ID in the workflow + provider_id: The plugin provider ID + trigger_name: The trigger name + subscription_id: The subscription ID + + Returns: + The created WorkflowPluginTrigger instance + + Raises: + BadRequest: If plugin trigger already exists for this app and node + """ + # Create trigger_id from provider_id and trigger_name + trigger_id = f"{provider_id}:{trigger_name}" + + with Session(db.engine) as session: + # Check if plugin trigger already exists for this app and node + # Based on unique constraint: uniq_app_node + existing_trigger = session.scalar( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.app_id == app_id, + WorkflowPluginTrigger.node_id == node_id, + ) + ) + + if existing_trigger: + raise BadRequest("Plugin trigger already exists for this app and node") + + # Create new plugin trigger + plugin_trigger = WorkflowPluginTrigger( + app_id=app_id, + node_id=node_id, + tenant_id=tenant_id, + provider_id=provider_id, + trigger_id=trigger_id, + subscription_id=subscription_id, + ) + + session.add(plugin_trigger) + session.commit() + session.refresh(plugin_trigger) + + return plugin_trigger + + @classmethod + def get_plugin_trigger( + cls, + app_id: str, + node_id: str, + ) -> WorkflowPluginTrigger: + """Get a plugin trigger by app_id and node_id + + Args: + app_id: The app ID + node_id: The node ID in the workflow + + Returns: + The WorkflowPluginTrigger instance + + Raises: + NotFound: If plugin trigger not found + """ + with Session(db.engine) as session: + # Find plugin trigger using unique constraint + plugin_trigger = session.scalar( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.app_id == app_id, + WorkflowPluginTrigger.node_id == node_id, + ) + ) + + if not plugin_trigger: + raise NotFound("Plugin trigger not found") + + return plugin_trigger + + @classmethod + def get_plugin_trigger_by_subscription( + cls, + tenant_id: str, + subscription_id: str, + ) -> WorkflowPluginTrigger: + """Get a plugin trigger by tenant_id and subscription_id + This is the primary query pattern, optimized with composite index + + Args: + tenant_id: The tenant ID + subscription_id: The subscription ID + + Returns: + The WorkflowPluginTrigger instance + + Raises: + NotFound: If plugin trigger not found + """ + with Session(db.engine) as session: + # Find plugin trigger using indexed columns + plugin_trigger = session.scalar( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.tenant_id == tenant_id, + WorkflowPluginTrigger.subscription_id == subscription_id, + ) + ) + + if not plugin_trigger: + raise NotFound("Plugin trigger not found") + + return plugin_trigger + + @classmethod + def list_plugin_triggers_by_tenant( + cls, + tenant_id: str, + ) -> list[WorkflowPluginTrigger]: + """List all plugin triggers for a tenant + + Args: + tenant_id: The tenant ID + + Returns: + List of WorkflowPluginTrigger instances + """ + with Session(db.engine) as session: + plugin_triggers = session.scalars( + select(WorkflowPluginTrigger) + .where(WorkflowPluginTrigger.tenant_id == tenant_id) + .order_by(WorkflowPluginTrigger.created_at.desc()) + ).all() + + return list(plugin_triggers) + + @classmethod + def list_plugin_triggers_by_subscription( + cls, + subscription_id: str, + ) -> list[WorkflowPluginTrigger]: + """List all plugin triggers for a subscription + + Args: + subscription_id: The subscription ID + + Returns: + List of WorkflowPluginTrigger instances + """ + with Session(db.engine) as session: + plugin_triggers = session.scalars( + select(WorkflowPluginTrigger) + .where(WorkflowPluginTrigger.subscription_id == subscription_id) + .order_by(WorkflowPluginTrigger.created_at.desc()) + ).all() + + return list(plugin_triggers) + + @classmethod + def update_plugin_trigger( + cls, + app_id: str, + node_id: str, + provider_id: Optional[str] = None, + trigger_name: Optional[str] = None, + subscription_id: Optional[str] = None, + ) -> WorkflowPluginTrigger: + """Update a plugin trigger + + Args: + app_id: The app ID + node_id: The node ID in the workflow + provider_id: The new provider ID (optional) + trigger_name: The new trigger name (optional) + subscription_id: The new subscription ID (optional) + + Returns: + The updated WorkflowPluginTrigger instance + + Raises: + NotFound: If plugin trigger not found + """ + with Session(db.engine) as session: + # Find plugin trigger using unique constraint + plugin_trigger = session.scalar( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.app_id == app_id, + WorkflowPluginTrigger.node_id == node_id, + ) + ) + + if not plugin_trigger: + raise NotFound("Plugin trigger not found") + + # Update fields if provided + if provider_id: + plugin_trigger.provider_id = provider_id + + if trigger_name: + # Update trigger_id if provider_id or trigger_name changed + provider_id = provider_id or plugin_trigger.provider_id + plugin_trigger.trigger_id = f"{provider_id}:{trigger_name}" + + if subscription_id: + plugin_trigger.subscription_id = subscription_id + + session.commit() + session.refresh(plugin_trigger) + + return plugin_trigger + + @classmethod + def update_plugin_trigger_by_subscription( + cls, + tenant_id: str, + subscription_id: str, + provider_id: Optional[str] = None, + trigger_name: Optional[str] = None, + new_subscription_id: Optional[str] = None, + ) -> WorkflowPluginTrigger: + """Update a plugin trigger by tenant_id and subscription_id + + Args: + tenant_id: The tenant ID + subscription_id: The current subscription ID + provider_id: The new provider ID (optional) + trigger_name: The new trigger name (optional) + new_subscription_id: The new subscription ID (optional) + + Returns: + The updated WorkflowPluginTrigger instance + + Raises: + NotFound: If plugin trigger not found + """ + with Session(db.engine) as session: + # Find plugin trigger using indexed columns + plugin_trigger = session.scalar( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.tenant_id == tenant_id, + WorkflowPluginTrigger.subscription_id == subscription_id, + ) + ) + + if not plugin_trigger: + raise NotFound("Plugin trigger not found") + + # Update fields if provided + if provider_id: + plugin_trigger.provider_id = provider_id + + if trigger_name: + # Update trigger_id if provider_id or trigger_name changed + provider_id = provider_id or plugin_trigger.provider_id + plugin_trigger.trigger_id = f"{provider_id}:{trigger_name}" + + if new_subscription_id: + plugin_trigger.subscription_id = new_subscription_id + + session.commit() + session.refresh(plugin_trigger) + + return plugin_trigger + + @classmethod + def delete_plugin_trigger( + cls, + app_id: str, + node_id: str, + ) -> None: + """Delete a plugin trigger by app_id and node_id + + Args: + app_id: The app ID + node_id: The node ID in the workflow + + Raises: + NotFound: If plugin trigger not found + """ + with Session(db.engine) as session: + # Find plugin trigger using unique constraint + plugin_trigger = session.scalar( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.app_id == app_id, + WorkflowPluginTrigger.node_id == node_id, + ) + ) + + if not plugin_trigger: + raise NotFound("Plugin trigger not found") + + session.delete(plugin_trigger) + session.commit() + + @classmethod + def delete_plugin_trigger_by_subscription( + cls, + tenant_id: str, + subscription_id: str, + ) -> None: + """Delete a plugin trigger by tenant_id and subscription_id + + Args: + tenant_id: The tenant ID + subscription_id: The subscription ID + + Raises: + NotFound: If plugin trigger not found + """ + with Session(db.engine) as session: + # Find plugin trigger using indexed columns + plugin_trigger = session.scalar( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.tenant_id == tenant_id, + WorkflowPluginTrigger.subscription_id == subscription_id, + ) + ) + + if not plugin_trigger: + raise NotFound("Plugin trigger not found") + + session.delete(plugin_trigger) + session.commit() + + @classmethod + def delete_all_by_subscription( + cls, + subscription_id: str, + ) -> int: + """Delete all plugin triggers for a subscription + Useful when a subscription is cancelled + + Args: + subscription_id: The subscription ID + + Returns: + Number of triggers deleted + """ + with Session(db.engine) as session: + # Find all plugin triggers for this subscription + plugin_triggers = session.scalars( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.subscription_id == subscription_id, + ) + ).all() + + count = len(plugin_triggers) + + for trigger in plugin_triggers: + session.delete(trigger) + + session.commit() + + return count