feat(trigger): add trigger provider management and webhook handling functionality

This commit is contained in:
Harry 2025-08-28 11:46:35 +08:00
parent 7544b5ec9a
commit 87120ad4ac
28 changed files with 2056 additions and 57 deletions

View File

@ -59,6 +59,7 @@ pnpm test # Run Jest tests
- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) appropriately
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead
### TypeScript/JavaScript

View File

@ -1207,6 +1207,55 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
"""
Setup system trigger oauth client
"""
from core.plugin.entities.plugin import TriggerProviderID
from models.trigger import TriggerOAuthSystemClient
provider_id = TriggerProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(TriggerOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = TriggerOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
"""
Find draft variables that reference non-existent apps.

View File

@ -22,8 +22,8 @@ from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService

View File

@ -0,0 +1,323 @@
import logging
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, Forbidden
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from libs.login import current_user, login_required
from models.account import Account
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
class TriggerProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""List all trigger providers for the current tenant"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
return TriggerProviderService.list_trigger_providers(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
except Exception as e:
logger.exception("Error listing trigger providers", exc_info=e)
raise
class TriggerProviderCredentialsAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Add a new credential instance for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("expires_at", type=int, required=False, nullable=True, location="json", default=-1)
args = parser.parse_args()
try:
# Parse credential type
try:
credential_type = CredentialType(args["credential_type"])
except ValueError:
raise BadRequest(f"Invalid credential_type. Must be one of: {[t.value for t in CredentialType]}")
result = TriggerProviderService.add_trigger_provider(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
credential_type=credential_type,
credentials=args["credentials"],
name=args.get("name"),
expires_at=args.get("expires_at", -1),
)
return result
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error adding provider credential", exc_info=e)
raise
class TriggerProviderCredentialsUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, credential_id):
"""Update an existing credential instance"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
result = TriggerProviderService.update_trigger_provider(
tenant_id=user.current_tenant_id,
credential_id=credential_id,
credentials=args.get("credentials"),
name=args.get("name"),
)
return result
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error updating provider credential", exc_info=e)
raise
class TriggerProviderCredentialsDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, credential_id):
"""Delete a credential instance"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
result = TriggerProviderService.delete_trigger_provider(
tenant_id=user.current_tenant_id,
credential_id=credential_id,
)
return result
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error deleting provider credential", exc_info=e)
raise
class TriggerProviderOAuthAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Initiate OAuth authorization flow for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
context_id = TriggerProviderService.create_oauth_proxy_context(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
)
# TODO: Build OAuth authorization URL
# This will be implemented when we have provider-specific OAuth configs
return {
"context_id": context_id,
"authorization_url": f"/oauth/authorize?context={context_id}",
}
except Exception as e:
logger.exception("Error initiating OAuth flow", exc_info=e)
raise
class TriggerProviderOAuthRefreshTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, credential_id):
"""Refresh OAuth token for a trigger provider credential"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
result = TriggerProviderService.refresh_oauth_token(
tenant_id=user.current_tenant_id,
credential_id=credential_id,
)
return result
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error refreshing OAuth token", exc_info=e)
raise
class TriggerProviderOAuthClientManageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
"""Get OAuth client configuration for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
# Get custom OAuth client params if exists
custom_params = TriggerProviderService.get_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if custom client is enabled
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
# Check if there's a system OAuth client
system_client = TriggerProviderService.get_oauth_client(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
return {
"configured": bool(custom_params or system_client),
"custom_configured": bool(custom_params),
"custom_enabled": is_custom_enabled,
"params": custom_params if custom_params else {},
}
except Exception as e:
logger.exception("Error getting OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
"""Configure custom OAuth client for a provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
try:
provider_id = TriggerProviderID(provider)
result = TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
)
return result
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error configuring OAuth client", exc_info=e)
raise
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
"""Remove custom OAuth client configuration"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try:
provider_id = TriggerProviderID(provider)
result = TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
return result
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e)
raise
# Trigger provider endpoints
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-provider/<path:provider>/list")
api.add_resource(TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider/<path:provider>/add")
api.add_resource(
TriggerProviderCredentialsUpdateApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/update"
)
api.add_resource(
TriggerProviderCredentialsDeleteApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/delete"
)
api.add_resource(
TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
)
api.add_resource(
TriggerProviderOAuthRefreshTokenApi,
"/workspaces/current/trigger-provider/credentials/<path:credential_id>/oauth/refresh",
)
api.add_resource(
TriggerProviderOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
)

View File

@ -4,4 +4,4 @@ from flask import Blueprint
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
# Import routes after blueprint creation to avoid circular imports
from . import webhook
from . import trigger, webhook

View File

@ -0,0 +1,23 @@
import logging
from flask import jsonify, request
from werkzeug.exceptions import NotFound
from controllers.trigger import bp
from services.trigger_service import TriggerService
logger = logging.getLogger(__name__)
@bp.route("/trigger/webhook/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def trigger_webhook(endpoint_id: str):
"""
Handle webhook trigger calls.
"""
try:
return TriggerService.process_webhook(endpoint_id, request)
except ValueError as e:
raise NotFound(str(e))
except Exception as e:
logger.exception("Webhook processing failed for {endpoint_id}")
return jsonify({"error": "Internal server error", "message": str(e)}), 500

View File

@ -68,6 +68,19 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class TriggerProviderCredentialCache(ProviderCredentialsCache):
"""Cache for trigger provider credentials"""
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]
credential_id = kwargs["credential_id"]
return f"trigger_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""

View File

@ -184,6 +184,10 @@ class ToolProviderID(GenericProviderID):
self.plugin_name = f"{self.provider_name}_tool"
class TriggerProviderID(GenericProviderID):
pass
class PluginDependency(BaseModel):
class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value

View File

@ -1,3 +1,4 @@
import enum
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
@ -13,6 +14,7 @@ from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
@ -196,3 +198,43 @@ class PluginListResponse(BaseModel):
class PluginDynamicSelectOptionsResponse(BaseModel):
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
class PluginTriggerProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
declaration: TriggerProviderEntity
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@ -4,10 +4,10 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.plugin.utils.chunk_merger import merge_blob_chunks
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient):

View File

@ -0,0 +1,68 @@
from typing import Any
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import PluginToolProviderEntity, PluginTriggerProviderEntity
from core.plugin.impl.base import BasePluginClient
class PluginTriggerManager(BasePluginClient):
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
"""
Fetch tool providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict:
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/tools",
list[PluginToolProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.tools:
tool.identity.provider = provider.declaration.identity.name
return response
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
"""
Fetch tool provider for the given tenant and plugin.
"""
tool_provider_id = ToolProviderID(provider)
def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data")
if data:
for tool in data.get("declaration", {}).get("tools", []):
tool["identity"]["provider"] = tool_provider_id.provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/tool",
PluginToolProviderEntity,
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
transformer=transformer,
)
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in response.declaration.tools:
tool.identity.provider = response.declaration.identity.name
return response

View File

@ -4,7 +4,8 @@ from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.entities.tool_entities import ToolInvokeFrom
class ToolRuntime(BaseModel):

View File

@ -4,11 +4,11 @@ from typing import Any
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema,
ToolEntity,
ToolProviderEntity,

View File

@ -4,9 +4,10 @@ from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
from core.tools.entities.tool_entities import ToolProviderType
class ToolApiEntity(BaseModel):

View File

@ -476,36 +476,3 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@ -37,6 +37,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import Tool
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -47,7 +48,6 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,

View File

@ -0,0 +1 @@
# Core trigger module initialization

View File

@ -0,0 +1,244 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from core.tools.entities.common_entities import I18nObject
class TriggerParameterOption(BaseModel):
"""
The option of the trigger parameter
"""
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
class TriggerParameterType(StrEnum):
"""The type of the parameter"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
SELECT = "select"
FILE = "file"
FILES = "files"
MODEL_SELECTOR = "model-selector"
APP_SELECTOR = "app-selector"
OBJECT = "object"
ARRAY = "array"
DYNAMIC_SELECT = "dynamic-select"
class ParameterAutoGenerate(BaseModel):
"""Auto generation configuration for parameters"""
enabled: bool = Field(default=False, description="Whether auto generation is enabled")
template: Optional[str] = Field(default=None, description="Template for auto generation")
class ParameterTemplate(BaseModel):
"""Template configuration for parameters"""
value: str = Field(..., description="Template value")
type: str = Field(default="jinja2", description="Template type")
class TriggerParameter(BaseModel):
"""
The parameter of the trigger
"""
name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user")
type: TriggerParameterType = Field(..., description="The type of the parameter")
auto_generate: Optional[ParameterAutoGenerate] = Field(
default=None, description="The auto generate of the parameter"
)
template: Optional[ParameterTemplate] = Field(default=None, description="The template of the parameter")
scope: Optional[str] = None
required: Optional[bool] = False
default: Union[int, float, str, None] = None
min: Union[float, int, None] = None
max: Union[float, int, None] = None
precision: Optional[int] = None
options: Optional[list[TriggerParameterOption]] = None
description: Optional[I18nObject] = None
class TriggerProviderIdentity(BaseModel):
"""
The identity of the trigger provider
"""
author: str = Field(..., description="The author of the trigger provider")
name: str = Field(..., description="The name of the trigger provider")
label: I18nObject = Field(..., description="The label of the trigger provider")
description: I18nObject = Field(..., description="The description of the trigger provider")
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
class TriggerIdentity(BaseModel):
"""
The identity of the trigger
"""
author: str = Field(..., description="The author of the trigger")
name: str = Field(..., description="The name of the trigger")
label: I18nObject = Field(..., description="The label of the trigger")
class TriggerDescription(BaseModel):
"""
The description of the trigger
"""
human: I18nObject = Field(..., description="Human readable description")
llm: I18nObject = Field(..., description="LLM readable description")
class TriggerConfigurationExtraPython(BaseModel):
"""Python configuration for trigger"""
source: str = Field(..., description="The source file path for the trigger implementation")
class TriggerConfigurationExtra(BaseModel):
"""
The extra configuration for trigger
"""
class TriggerEntity(BaseModel):
"""
The configuration of a trigger
"""
python: TriggerConfigurationExtraPython
identity: TriggerIdentity = Field(..., description="The identity of the trigger")
parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger")
description: TriggerDescription = Field(..., description="The description of the trigger")
extra: TriggerConfigurationExtra = Field(..., description="The extra configuration of the trigger")
output_schema: Optional[Mapping[str, Any]] = Field(
default=None, description="The output schema that this trigger produces"
)
class TriggerProviderConfigurationExtraPython(BaseModel):
"""Python configuration for trigger provider"""
source: str = Field(..., description="The source file path for the trigger provider implementation")
class TriggerProviderConfigurationExtra(BaseModel):
"""
The extra configuration for trigger provider
"""
python: TriggerProviderConfigurationExtraPython
class OAuthSchema(BaseModel):
"""OAuth configuration schema"""
authorization_url: str = Field(..., description="OAuth authorization URL")
token_url: str = Field(..., description="OAuth token URL")
client_id: str = Field(..., description="OAuth client ID")
client_secret: str = Field(..., description="OAuth client secret")
redirect_uri: Optional[str] = Field(default=None, description="OAuth redirect URI")
scope: Optional[str] = Field(default=None, description="OAuth scope")
class ProviderConfig(BaseModel):
"""Provider configuration item"""
name: str = Field(..., description="Configuration field name")
type: str = Field(..., description="Configuration field type")
required: bool = Field(default=False, description="Whether this field is required")
default: Any = Field(default=None, description="Default value")
label: Optional[I18nObject] = Field(default=None, description="Field label")
description: Optional[I18nObject] = Field(default=None, description="Field description")
options: Optional[list[dict[str, Any]]] = Field(default=None, description="Options for select type")
class TriggerProviderEntity(BaseModel):
"""
The configuration of a trigger provider
"""
identity: TriggerProviderIdentity = Field(..., description="The identity of the trigger provider")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list,
description="The credentials schema of the trigger provider",
)
oauth_schema: Optional[OAuthSchema] = Field(
default=None,
description="The OAuth schema of the trigger provider if OAuth is supported",
)
subscription_schema: list[ProviderConfig] = Field(
default_factory=list,
description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters",
)
triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")
extra: TriggerProviderConfigurationExtra = Field(..., description="The extra configuration of the trigger provider")
class Subscription(BaseModel):
"""
Result of a successful trigger subscription operation.
Contains all information needed to manage the subscription lifecycle.
"""
expire_at: int = Field(
..., description="The timestamp when the subscription will expire, this for refresh the subscription"
)
metadata: dict[str, Any] = Field(
..., description="Metadata about the subscription in the external service, defined in subscription_schema"
)
class Unsubscription(BaseModel):
"""
Result of a trigger unsubscription operation.
Provides detailed information about the unsubscription attempt,
including success status and error details if failed.
"""
success: bool = Field(..., description="Whether the unsubscription was successful")
message: Optional[str] = Field(
None,
description="Human-readable message about the operation result. "
"Success message for successful operations, "
"detailed error information for failures.",
)
# Export all entities
__all__ = [
"OAuthSchema",
"ParameterAutoGenerate",
"ParameterTemplate",
"ProviderConfig",
"Subscription",
"TriggerConfigurationExtra",
"TriggerConfigurationExtraPython",
"TriggerDescription",
"TriggerEntity",
"TriggerEntity",
"TriggerIdentity",
"TriggerParameter",
"TriggerParameterOption",
"TriggerParameterType",
"TriggerProviderConfigurationExtra",
"TriggerProviderConfigurationExtraPython",
"TriggerProviderEntity",
"TriggerProviderIdentity",
"Unsubscription",
]

View File

@ -0,0 +1,199 @@
"""
Trigger Provider Controller for managing trigger providers
"""
import logging
import time
from typing import Optional
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities import (
ProviderConfig,
Subscription,
TriggerEntity,
TriggerProviderEntity,
TriggerProviderIdentity,
Unsubscription,
)
logger = logging.getLogger(__name__)
class TriggerProviderController:
"""
Controller for plugin trigger providers
"""
def __init__(
self,
entity: TriggerProviderEntity,
plugin_id: str,
plugin_unique_identifier: str,
tenant_id: str,
):
"""
Initialize plugin trigger provider controller
:param entity: Trigger provider entity
:param plugin_id: Plugin ID
:param plugin_unique_identifier: Plugin unique identifier
:param tenant_id: Tenant ID
"""
self.entity = entity
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def identity(self) -> TriggerProviderIdentity:
"""Get provider identity"""
return self.entity.identity
def get_triggers(self) -> list[TriggerEntity]:
"""
Get all triggers for this provider
:return: List of trigger entities
"""
return self.entity.triggers
def get_trigger(self, trigger_name: str) -> Optional[TriggerEntity]:
"""
Get a specific trigger by name
:param trigger_name: Trigger name
:return: Trigger entity or None
"""
for trigger in self.entity.triggers:
if trigger.identity.name == trigger_name:
return trigger
return None
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
Get credentials schema for this provider
:return: List of provider config schemas
"""
return self.entity.credentials_schema
def get_subscription_schema(self) -> list[ProviderConfig]:
"""
Get subscription schema for this provider
:return: List of subscription config schemas
"""
return self.entity.subscription_schema
def validate_credentials(self, credentials: dict) -> None:
"""
Validate credentials against schema
:param credentials: Credentials to validate
:raises ValueError: If credentials are invalid
"""
for config in self.entity.credentials_schema:
if config.required and config.name not in credentials:
raise ValueError(f"Missing required credential field: {config.name}")
def get_supported_credential_types(self) -> list[CredentialType]:
"""
Get supported credential types for this provider.
:return: List of supported credential types
"""
types = []
if self.entity.oauth_schema:
types.append(CredentialType.OAUTH2)
if self.entity.credentials_schema:
types.append(CredentialType.API_KEY)
return types
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
Get credentials schema by credential type
:param credential_type: The type of credential (oauth or api_key)
:return: List of provider config schemas
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
Get OAuth client schema for this provider
:return: List of OAuth client config schemas
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
@property
def need_credentials(self) -> bool:
"""Check if this provider needs credentials"""
return len(self.get_supported_credential_types()) > 0
def execute_trigger(self, trigger_name: str, parameters: dict, credentials: dict) -> dict:
"""
Execute a trigger through plugin runtime
:param trigger_name: Trigger name
:param parameters: Trigger parameters
:param credentials: Provider credentials
:return: Execution result
"""
logger.info("Executing trigger %s for plugin %s", trigger_name, self.plugin_id)
return {
"success": True,
"trigger": trigger_name,
"plugin": self.plugin_id,
"result": "Trigger executed successfully",
}
def subscribe_trigger(self, trigger_name: str, subscription_params: dict, credentials: dict) -> Subscription:
"""
Subscribe to a trigger through plugin runtime
:param trigger_name: Trigger name
:param subscription_params: Subscription parameters
:param credentials: Provider credentials
:return: Subscription result
"""
logger.info("Subscribing to trigger %s for plugin %s", trigger_name, self.plugin_id)
return Subscription(
expire_at=int(time.time()) + 86400, # 24 hours from now
metadata={
"subscription_id": f"{self.plugin_id}_{trigger_name}_{time.time()}",
"webhook_url": f"/triggers/webhook/{self.plugin_id}/{trigger_name}",
**subscription_params,
},
)
def unsubscribe_trigger(self, trigger_name: str, subscription_metadata: dict, credentials: dict) -> Unsubscription:
"""
Unsubscribe from a trigger through plugin runtime
:param trigger_name: Trigger name
:param subscription_metadata: Subscription metadata
:param credentials: Provider credentials
:return: Unsubscription result
"""
logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id)
return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}")
def handle_webhook(self, webhook_path: str, request_data: dict, credentials: dict) -> dict:
"""
Handle incoming webhook through plugin runtime
:param webhook_path: Webhook path
:param request_data: Request data
:param credentials: Provider credentials
:return: Webhook handling result
"""
logger.info("Handling webhook for path %s for plugin %s", webhook_path, self.plugin_id)
return {"success": True, "path": webhook_path, "plugin": self.plugin_id, "data_received": request_data}
__all__ = ["TriggerProviderController"]

View File

@ -0,0 +1,360 @@
"""
Trigger Manager for loading and managing trigger providers and triggers
"""
import logging
from typing import Optional
from core.trigger.entities import (
ProviderConfig,
TriggerEntity,
)
from core.trigger.plugin_trigger import PluginTriggerController
from core.trigger.provider import PluginTriggerProviderController
logger = logging.getLogger(__name__)
class TriggerManager:
"""
Manager for trigger providers and triggers
"""
@classmethod
def list_plugin_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
"""
List all plugin trigger providers for a tenant
:param tenant_id: Tenant ID
:return: List of trigger provider controllers
"""
manager = PluginTriggerController()
provider_entities = manager.fetch_trigger_providers(tenant_id)
controllers = []
for provider in provider_entities:
try:
controller = PluginTriggerProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id,
)
controllers.append(controller)
except Exception as e:
logger.exception("Failed to load trigger provider {provider.plugin_id}")
continue
return controllers
@classmethod
def get_plugin_trigger_provider(
cls, tenant_id: str, plugin_id: str, provider_name: str
) -> Optional[PluginTriggerProviderController]:
"""
Get a specific plugin trigger provider
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: Trigger provider controller or None
"""
manager = PluginTriggerManager()
provider = manager.fetch_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return None
try:
return PluginTriggerProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id,
)
except Exception as e:
logger.exception("Failed to load trigger provider")
return None
@classmethod
def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
"""
List all trigger providers (plugin and builtin)
:param tenant_id: Tenant ID
:return: List of all trigger provider controllers
"""
providers = []
# Get plugin providers
plugin_providers = cls.list_plugin_trigger_providers(tenant_id)
providers.extend(plugin_providers)
# TODO: Add builtin providers when implemented
# builtin_providers = cls.list_builtin_trigger_providers(tenant_id)
# providers.extend(builtin_providers)
return providers
@classmethod
def list_triggers_by_provider(cls, tenant_id: str, plugin_id: str, provider_name: str) -> list[TriggerEntity]:
"""
List all triggers for a specific provider
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: List of trigger entities
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return []
return provider.get_triggers()
@classmethod
def get_trigger(
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str
) -> Optional[TriggerEntity]:
"""
Get a specific trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param trigger_name: Trigger name
:return: Trigger entity or None
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return None
return provider.get_trigger(trigger_name)
@classmethod
def validate_trigger_credentials(
cls, tenant_id: str, plugin_id: str, provider_name: str, credentials: dict
) -> tuple[bool, str]:
"""
Validate trigger provider credentials
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param credentials: Credentials to validate
:return: Tuple of (is_valid, error_message)
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return False, "Provider not found"
try:
provider.validate_credentials(credentials)
return True, ""
except Exception as e:
return False, str(e)
@classmethod
def execute_trigger(
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str, parameters: dict, credentials: dict
) -> dict:
"""
Execute a trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param trigger_name: Trigger name
:param parameters: Trigger parameters
:param credentials: Provider credentials
:return: Trigger execution result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
trigger = provider.get_trigger(trigger_name)
if not trigger:
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_name}")
return provider.execute_trigger(trigger_name, parameters, credentials)
@classmethod
def subscribe_trigger(
cls,
tenant_id: str,
plugin_id: str,
provider_name: str,
trigger_name: str,
subscription_params: dict,
credentials: dict,
) -> dict:
"""
Subscribe to a trigger (e.g., register webhook)
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param trigger_name: Trigger name
:param subscription_params: Subscription parameters
:param credentials: Provider credentials
:return: Subscription result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.subscribe_trigger(trigger_name, subscription_params, credentials)
@classmethod
def unsubscribe_trigger(
cls,
tenant_id: str,
plugin_id: str,
provider_name: str,
trigger_name: str,
subscription_metadata: dict,
credentials: dict,
) -> dict:
"""
Unsubscribe from a trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param trigger_name: Trigger name
:param subscription_metadata: Subscription metadata from subscribe operation
:param credentials: Provider credentials
:return: Unsubscription result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.unsubscribe_trigger(trigger_name, subscription_metadata, credentials)
@classmethod
def handle_webhook(
cls,
tenant_id: str,
plugin_id: str,
provider_name: str,
webhook_path: str,
request_data: dict,
credentials: dict,
) -> dict:
"""
Handle incoming webhook for a trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param webhook_path: Webhook path
:param request_data: Webhook request data
:param credentials: Provider credentials
:return: Webhook handling result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.handle_webhook(webhook_path, request_data, credentials)
@classmethod
def get_provider_credentials_schema(
cls, tenant_id: str, plugin_id: str, provider_name: str
) -> list[ProviderConfig]:
"""
Get provider credentials schema
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: List of provider config schemas
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return []
return provider.get_credentials_schema()
@classmethod
def get_provider_subscription_schema(
cls, tenant_id: str, plugin_id: str, provider_name: str
) -> list[ProviderConfig]:
"""
Get provider subscription schema
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: List of subscription config schemas
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return []
return provider.get_subscription_schema()
@classmethod
def get_provider_info(cls, tenant_id: str, plugin_id: str, provider_name: str) -> Optional[dict]:
"""
Get provider information
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: Provider info dict or None
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return None
return {
"plugin_id": plugin_id,
"provider_name": provider_name,
"identity": provider.entity.identity.model_dump() if provider.entity.identity else {},
"credentials_schema": [c.model_dump() for c in provider.entity.credentials_schema],
"subscription_schema": [s.model_dump() for s in provider.entity.subscription_schema],
"oauth_enabled": provider.entity.oauth_schema is not None,
"trigger_count": len(provider.entity.triggers),
"triggers": [t.identity.model_dump() for t in provider.entity.triggers],
}
@classmethod
def list_providers_for_workflow(cls, tenant_id: str) -> list[dict]:
"""
List trigger providers suitable for workflow usage
:param tenant_id: Tenant ID
:return: List of provider info dicts
"""
providers = cls.list_all_trigger_providers(tenant_id)
result = []
for provider in providers:
info = {
"plugin_id": provider.plugin_id,
"provider_name": provider.entity.identity.name,
"label": provider.entity.identity.label.model_dump(),
"description": provider.entity.identity.description.model_dump(),
"icon": provider.entity.identity.icon,
"trigger_count": len(provider.entity.triggers),
}
result.append(info)
return result
# Export
__all__ = ["TriggerManager"]

View File

@ -20,6 +20,7 @@ def init_app(app: DifyApp):
reset_encrypt_key_pair,
reset_password,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
upgrade_db,
vdb_migrate,
)
@ -43,6 +44,7 @@ def init_app(app: DifyApp):
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,
]
for cmd in cmds_to_register:

97
api/models/trigger.py Normal file
View File

@ -0,0 +1,97 @@
import json
from datetime import UTC, datetime
from typing import cast
import sqlalchemy as sa
from sqlalchemy import DateTime, Index, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from core.plugin.entities.plugin_daemon import CredentialType
from models.base import Base
from models.types import StringUUID
class TriggerProvider(Base):
"""
Trigger provider model for managing credentials
Supports multiple credential instances per provider
"""
__tablename__ = "trigger_providers"
__table_args__ = (Index("idx_trigger_providers_tenant_provider", "tenant_id", "provider_id"),)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_id: Mapped[str] = mapped_column(
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
)
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
encrypted_credentials: Mapped[str] = mapped_column(Text, nullable=False, comment="Encrypted credentials JSON")
name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Credential instance name")
expires_at: Mapped[int] = mapped_column(
Integer, default=-1, comment="OAuth token expiration timestamp, -1 for never"
)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
)
@property
def credentials(self) -> dict:
"""Get credentials as dict (still encrypted)"""
try:
return json.loads(self.encrypted_credentials) if self.encrypted_credentials else {}
except (json.JSONDecodeError, TypeError):
return {}
@property
def credentials_str(self) -> str:
"""Get credentials as string"""
return self.encrypted_credentials or "{}"
def is_oauth_expired(self) -> bool:
"""Check if OAuth token is expired"""
if self.credential_type != CredentialType.OAUTH2.value:
return False
if self.expires_at == -1:
return False
# Check if token expires in next 60 seconds
return (self.expires_at - 60) < int(datetime.now(UTC).timestamp())
# system level trigger oauth client params
class TriggerOAuthSystemClient(Base):
__tablename__ = "trigger_oauth_system_clients"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trigger_oauth_system_client_pkey"),
sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
# oauth params of the trigger provider
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
# tenant level trigger oauth client params (client_id, client_secret, etc.)
class TriggerOAuthTenantClient(Base):
__tablename__ = "trigger_oauth_tenant_clients"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trigger_oauth_tenant_client_pkey"),
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
provider: Mapped[str] = mapped_column(String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
# oauth params of the trigger provider
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))

View File

@ -13,6 +13,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import (
@ -21,7 +22,6 @@ from core.tools.entities.api_entities import (
ToolProviderCredentialApiEntity,
ToolProviderCredentialInfoApiEntity,
)
from core.tools.entities.tool_entities import CredentialType
from core.tools.errors import ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
@ -39,7 +39,6 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
__DEFAULT_EXPIRES_AT__ = 2147483647
@staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
@ -278,9 +277,7 @@ class BuiltinToolManageService:
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
expires_at=expires_at
if expires_at is not None
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
expires_at=expires_at if expires_at is not None else -1,
)
session.add(db_provider)

View File

@ -7,6 +7,7 @@ from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -16,7 +17,6 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolParameter,
ToolProviderType,
)

View File

@ -0,0 +1,588 @@
import json
import logging
import re
from collections.abc import Mapping
from typing import Any, Optional
from sqlalchemy import desc
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.service import PluginService
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider
from services.plugin.oauth_service import OAuthProxyService
logger = logging.getLogger(__name__)
class TriggerProviderService:
"""Service for managing trigger providers and credentials"""
__MAX_TRIGGER_PROVIDER_COUNT__ = 100
@classmethod
def list_trigger_providers(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerProvider]:
"""List all trigger providers for the current tenant"""
# TODO fetch trigger plugin controller
# TODO fetch all trigger plugin credentials
with Session(db.engine, autoflush=False) as session:
return session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).all()
@classmethod
def add_trigger_provider(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
credential_type: CredentialType,
credentials: dict,
name: Optional[str] = None,
expires_at: int = -1,
) -> dict:
"""
Add a new trigger provider with credentials.
Supports multiple credential instances per provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier (e.g., "plugin_id/provider_name")
:param credential_type: Type of credential (oauth or api_key)
:param credentials: Credential data to encrypt and store
:param name: Optional name for this credential instance
:param expires_at: OAuth token expiration timestamp
:return: Success response
"""
try:
with Session(db.engine) as session:
# Use distributed lock to prevent race conditions
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
with redis_client.lock(lock_key, timeout=20):
# Check provider count limit
provider_count = (
session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).count()
)
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
raise ValueError(
f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) "
f"reached for {provider_id}"
)
# Generate name if not provided
if not name:
name = cls._generate_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type,
)
else:
# Check if name already exists
existing = (
session.query(TriggerProvider)
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
.first()
)
if existing:
raise ValueError(f"Credential name '{name}' already exists for this provider")
# Create encrypter for credentials
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[], # We'll define schema later in TriggerProvider classes
cache=NoOpProviderCredentialCache(),
)
# Create provider record
db_provider = TriggerProvider(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
credential_type=credential_type.value,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
name=name,
expires_at=expires_at,
)
session.add(db_provider)
session.commit()
return {"result": "success", "id": str(db_provider.id)}
except Exception as e:
logger.exception("Failed to add trigger provider")
raise ValueError(str(e))
@classmethod
def update_trigger_provider(
cls,
tenant_id: str,
credential_id: str,
credentials: Optional[dict] = None,
name: Optional[str] = None,
) -> dict:
"""
Update an existing trigger provider's credentials or name.
:param tenant_id: Tenant ID
:param credential_id: Credential instance ID
:param credentials: New credentials (optional)
:param name: New name (optional)
:return: Success response
"""
with Session(db.engine) as session:
# Get provider
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found")
try:
# Update credentials if provided
if credentials:
encrypter, cache = cls._create_provider_encrypter(
tenant_id=tenant_id,
provider=db_provider,
)
# Handle hidden values
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
cache.delete()
# Update name if provided
if name and name != db_provider.name:
# Check if name already exists
existing = (
session.query(TriggerProvider)
.filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name)
.filter(TriggerProvider.id != credential_id)
.first()
)
if existing:
raise ValueError(f"Credential name '{name}' already exists")
db_provider.name = name
session.commit()
return {"result": "success"}
except Exception as e:
session.rollback()
raise ValueError(str(e))
@classmethod
def delete_trigger_provider(cls, tenant_id: str, credential_id: str) -> dict:
"""
Delete a trigger provider credential.
:param tenant_id: Tenant ID
:param credential_id: Credential instance ID
:return: Success response
"""
with Session(db.engine) as session:
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found")
# Delete provider
session.delete(db_provider)
session.commit()
# Clear cache
_, cache = cls._create_provider_encrypter(tenant_id, db_provider)
cache.delete()
return {"result": "success"}
@classmethod
def refresh_oauth_token(
cls,
tenant_id: str,
credential_id: str,
) -> dict:
"""
Refresh OAuth token for a trigger provider.
:param tenant_id: Tenant ID
:param credential_id: Credential instance ID
:return: New token info
"""
with Session(db.engine) as session:
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found")
if db_provider.credential_type != CredentialType.OAUTH2.value:
raise ValueError("Only OAuth credentials can be refreshed")
# Parse provider ID
provider_id = TriggerProviderID(db_provider.provider_id)
# Create encrypter
encrypter, cache = cls._create_provider_encrypter(
tenant_id=tenant_id,
provider=db_provider,
)
# Decrypt current credentials
current_credentials = encrypter.decrypt(db_provider.credentials)
# Get OAuth client configuration
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{db_provider.provider_id}/trigger/callback"
)
system_credentials = cls.get_oauth_client(tenant_id, provider_id)
# Refresh token
oauth_handler = OAuthHandler()
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=db_provider.user_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=current_credentials,
)
# Update credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(dict(refreshed_credentials.credentials)))
db_provider.expires_at = refreshed_credentials.expires_at
session.commit()
# Clear cache
cache.delete()
return {
"result": "success",
"expires_at": refreshed_credentials.expires_at,
}
@classmethod
def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]:
"""
Get OAuth client configuration for a provider.
First tries tenant-level OAuth, then falls back to system OAuth.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:return: OAuth client configuration or None
"""
# Get trigger provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
# Create encrypter for OAuth client params
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine, autoflush=False) as session:
# First check for tenant-specific OAuth client
tenant_client: TriggerOAuthTenantClient | None = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
enabled=True,
)
.first()
)
oauth_params: Mapping[str, Any] | None = None
if tenant_client:
oauth_params = encrypter.decrypt(tenant_client.oauth_params)
return oauth_params
# Only verified plugins can use system OAuth client
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
if not is_verified:
return oauth_params
# Check for system-level OAuth client
system_client: TriggerOAuthSystemClient | None = (
session.query(TriggerOAuthSystemClient)
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
.first()
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")
return oauth_params
@classmethod
def save_custom_oauth_client_params(
cls,
tenant_id: str,
provider_id: TriggerProviderID,
client_params: Optional[dict] = None,
enabled: Optional[bool] = None,
) -> dict:
"""
Save or update custom OAuth client parameters for a trigger provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:param client_params: OAuth client parameters (client_id, client_secret, etc.)
:param enabled: Enable/disable the custom OAuth client
:return: Success response
"""
if client_params is None and enabled is None:
return {"result": "success"}
# Get provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
with Session(db.engine) as session:
# Find existing custom client params
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
.first()
)
# Create new record if doesn't exist
if custom_client is None:
custom_client = TriggerOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
session.add(custom_client)
# Update client params if provided
if client_params is not None:
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
# Handle hidden values
original_params = encrypter.decrypt(custom_client.oauth_params)
new_params: dict = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
# Update enabled status if provided
if enabled is not None:
custom_client.enabled = enabled
session.commit()
return {"result": "success"}
@classmethod
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
"""
Get custom OAuth client parameters for a trigger provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:return: Masked OAuth client parameters
"""
with Session(db.engine) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
.first()
)
if custom_client is None:
return {}
# Get provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
# Create encrypter to decrypt and mask values
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params))
@classmethod
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
"""
Delete custom OAuth client parameters for a trigger provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:return: Success response
"""
with Session(db.engine) as session:
session.query(TriggerOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
).delete()
session.commit()
return {"result": "success"}
@classmethod
def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
"""
Check if custom OAuth client is enabled for a trigger provider.
:param tenant_id: Tenant ID
:param provider_id: Provider identifier
:return: True if enabled, False otherwise
"""
with Session(db.engine, autoflush=False) as session:
custom_client = (
session.query(TriggerOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
enabled=True,
)
.first()
)
return custom_client is not None
@classmethod
def create_oauth_proxy_context(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
) -> str:
"""
Create OAuth proxy context for authorization flow.
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider: Provider identifier
:return: Context ID for OAuth flow
"""
return OAuthProxyService.create_proxy_context(
user_id=user_id,
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
@classmethod
def _create_provider_encrypter(
cls, tenant_id: str, provider: TriggerProvider
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
"""
Create encrypter and cache for trigger provider credentials
:param tenant_id: Tenant ID
:param provider: TriggerProvider instance
:return: Tuple of encrypter and cache
"""
from core.helper.provider_cache import TriggerProviderCredentialCache
# Parse provider ID
provider_id = TriggerProviderID(provider.provider_id)
# Get trigger provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
# Create encrypter with appropriate schema based on credential type
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(provider.credential_type)
],
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider=provider.provider_id,
credential_id=provider.id,
),
)
return encrypter, cache
@classmethod
def _generate_provider_name(
cls,
session: Session,
tenant_id: str,
provider_id: TriggerProviderID,
credential_type: CredentialType,
) -> str:
"""
Generate a unique name for a provider credential instance.
:param session: Database session
:param tenant_id: Tenant ID
:param provider: Provider identifier
:param credential_type: Credential type
:return: Generated name
"""
try:
db_providers = (
session.query(TriggerProvider)
.filter_by(
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type.value,
)
.order_by(desc(TriggerProvider.created_at))
.all()
)
# Get base name
base_name = credential_type.get_name()
# Find existing numbered names
pattern = rf"^{re.escape(base_name)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# Generate next number
if not numbers:
return f"{base_name} 1"
max_number = max(numbers)
return f"{base_name} {max_number + 1}"
except Exception as e:
logger.warning("Error generating provider name")
return f"{credential_type.get_name()} 1"

View File

@ -0,0 +1,23 @@
import logging
from typing import Any
from flask import Request, Response
logger = logging.getLogger(__name__)
class TriggerService:
__MAX_REQUEST_LOG_COUNT__ = 10
@classmethod
def process_webhook(cls, webhook_id: str, request: Request) -> Response:
"""Extract and process data from incoming webhook request."""
# TODO redis slidingwindow log, save the recent request log in redis, rollover the log when the window is full
# TODO find the trigger subscription
# TODO fetch the trigger controller
# TODO dispatch by the trigger controller
# TODO using the dispatch result(events) to invoke the trigger events

View File

@ -55,9 +55,7 @@ class BaseQueueDispatcher(ABC):
True if quota available, False otherwise
"""
# Check without consuming
remaining = self.rate_limiter.get_remaining_quota(
tenant_id=tenant_id, max_daily_limit=self.get_daily_limit()
)
remaining = self.rate_limiter.get_remaining_quota(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit())
return remaining > 0
def consume_quota(self, tenant_id: str) -> bool:
@ -70,9 +68,7 @@ class BaseQueueDispatcher(ABC):
Returns:
True if quota consumed successfully, False if limit reached
"""
return self.rate_limiter.check_and_consume(
tenant_id=tenant_id, max_daily_limit=self.get_daily_limit()
)
return self.rate_limiter.check_and_consume(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit())
class ProfessionalQueueDispatcher(BaseQueueDispatcher):

View File

@ -74,10 +74,10 @@ class TenantDailyRateLimiter:
Number of seconds until UTC midnight
"""
utc_now = datetime.utcnow()
# Get next midnight in UTC
next_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min)
return int((next_midnight - utc_now).total_seconds())
def check_and_consume(self, tenant_id: str, max_daily_limit: int) -> bool:
@ -174,9 +174,9 @@ class TenantDailyRateLimiter:
"""
tz = pytz.timezone(timezone_str)
utc_now = datetime.utcnow()
# Get next midnight in UTC, then convert to tenant's timezone
next_utc_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min)
next_utc_midnight = pytz.UTC.localize(next_utc_midnight)
return next_utc_midnight.astimezone(tz)