feat(trigger): implement complete OAuth authorization flow for trigger providers

- Add OAuth authorization URL generation API endpoint
- Implement OAuth callback handler for credential storage
- Support both system-level and tenant-level OAuth clients
- Add trigger provider credential encryption utilities
- Refactor trigger entities into separate modules
- Update trigger provider service with OAuth client management
- Add credential cache for trigger providers

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Harry 2025-08-28 15:20:15 +08:00
parent 87120ad4ac
commit a46c9238fa
13 changed files with 420 additions and 456 deletions

View File

@ -1,20 +1,37 @@
import logging
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from libs.login import current_user, login_required
from models.account import Account
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__)
class TriggerProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List all trigger providers for the current tenant"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
class TriggerProviderCredentialListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@ -27,8 +44,10 @@ class TriggerProviderListApi(Resource):
raise Forbidden()
try:
return TriggerProviderService.list_trigger_providers(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
return jsonable_encoder(
TriggerProviderService.list_trigger_provider_credentials(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
)
except Exception as e:
logger.exception("Error listing trigger providers", exc_info=e)
@ -145,30 +164,128 @@ class TriggerProviderOAuthAuthorizeApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
"""Initiate OAuth authorization flow for a provider"""
"""Initiate OAuth authorization flow for a trigger provider"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
try:
context_id = TriggerProviderService.create_oauth_proxy_context(
tenant_id=user.current_tenant_id,
user_id=user.id,
provider_id=TriggerProviderID(provider),
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
tenant_id = user.current_tenant_id
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
# TODO: Build OAuth authorization URL
# This will be implemented when we have provider-specific OAuth configs
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
return {
"context_id": context_id,
"authorization_url": f"/oauth/authorize?context={context_id}",
}
# Create OAuth handler and proxy context
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
)
# Build redirect URI for callback
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
# Get authorization URL
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
# Create response with cookie
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
except Exception as e:
logger.exception("Error initiating OAuth flow", exc_info=e)
raise
class TriggerProviderOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
# Use and validate proxy context
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
# Parse provider ID
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Get OAuth credentials from callback
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
)
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials:
raise Exception("Failed to get OAuth credentials")
# Save OAuth credentials to database
TriggerProviderService.add_trigger_provider(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
credential_type=CredentialType.OAUTH2,
credentials=dict(credentials),
expires_at=expires_at,
)
# Redirect to OAuth callback page
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class TriggerProviderOAuthRefreshTokenApi(Resource):
@setup_required
@login_required
@ -257,16 +374,13 @@ class TriggerProviderOAuthClientManageApi(Resource):
try:
provider_id = TriggerProviderID(provider)
result = TriggerProviderService.save_custom_oauth_client_params(
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
)
return result
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
@ -287,13 +401,10 @@ class TriggerProviderOAuthClientManageApi(Resource):
try:
provider_id = TriggerProviderID(provider)
result = TriggerProviderService.delete_custom_oauth_client_params(
return TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
return result
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:
@ -302,8 +413,12 @@ class TriggerProviderOAuthClientManageApi(Resource):
# Trigger provider endpoints
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-provider/<path:provider>/list")
api.add_resource(TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider/<path:provider>/add")
api.add_resource(
TriggerProviderCredentialListApi, "/workspaces/current/trigger-provider/credentials/<path:provider>/list"
)
api.add_resource(
TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider/credentials/<path:provider>/add"
)
api.add_resource(
TriggerProviderCredentialsUpdateApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/update"
)
@ -311,9 +426,11 @@ api.add_resource(
TriggerProviderCredentialsDeleteApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/delete"
)
# OAuth
api.add_resource(
TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
)
api.add_resource(TriggerProviderOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(
TriggerProviderOAuthRefreshTokenApi,
"/workspaces/current/trigger-provider/credentials/<path:credential_id>/oauth/refresh",

View File

@ -71,15 +71,25 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
class TriggerProviderCredentialCache(ProviderCredentialsCache):
"""Cache for trigger provider credentials"""
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]
provider_id = kwargs["provider_id"]
credential_id = kwargs["credential_id"]
return f"trigger_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
class TriggerProviderOAuthClientCache(ProviderCredentialsCache):
"""Cache for trigger provider OAuth client"""
def __init__(self, tenant_id: str, provider_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_id = kwargs["provider_id"]
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""

View File

@ -14,7 +14,7 @@ from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities import TriggerProviderEntity
from core.trigger.entities.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))

View File

@ -1,7 +1,7 @@
from typing import Any
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import PluginToolProviderEntity, PluginTriggerProviderEntity
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import PluginTriggerProviderEntity
from core.plugin.impl.base import BasePluginClient
@ -15,15 +15,15 @@ class PluginTriggerManager(BasePluginClient):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name
for trigger in declaration.get("triggers", []):
trigger["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/tools",
list[PluginToolProviderEntity],
f"plugin/{tenant_id}/management/triggers",
list[PluginTriggerProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
@ -32,37 +32,36 @@ class PluginTriggerManager(BasePluginClient):
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.tools:
tool.identity.provider = provider.declaration.identity.name
for trigger in provider.declaration.triggers:
trigger.identity.provider = provider.declaration.identity.name
return response
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
"""
Fetch tool provider for the given tenant and plugin.
"""
tool_provider_id = ToolProviderID(provider)
def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data")
if data:
for tool in data.get("declaration", {}).get("tools", []):
tool["identity"]["provider"] = tool_provider_id.provider_name
for trigger in data.get("declaration", {}).get("triggers", []):
trigger["identity"]["provider"] = provider_id.provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/tool",
PluginToolProviderEntity,
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
f"plugin/{tenant_id}/management/trigger",
PluginTriggerProviderEntity,
params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
transformer=transformer,
)
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in response.declaration.tools:
tool.identity.provider = response.declaration.identity.name
# override the provider name for each trigger to plugin_id/provider_name
for trigger in response.declaration.triggers:
trigger.identity.provider = response.declaration.identity.name
return response

View File

@ -122,7 +122,6 @@ class ProviderConfigEncrypter:
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache

View File

@ -0,0 +1,40 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import (
OAuthSchema,
TriggerDescription,
TriggerEntity,
TriggerParameter,
TriggerProviderIdentity,
)
class TriggerProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider id of the credential")
credential_type: CredentialType = Field(description="The type of the credential")
credentials: dict = Field(description="The credentials of the credential")
class TriggerProviderApiEntity(BaseModel):
identity: TriggerProviderIdentity = Field(description="The identity of the trigger provider")
credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
oauth_schema: Optional[OAuthSchema] = Field(description="The OAuth schema of the trigger provider")
subscription_schema: list[ProviderConfig] = Field(description="The subscription schema of the trigger provider")
triggers: list[TriggerEntity] = Field(description="The triggers of the trigger provider")
class TriggerApiEntity(BaseModel):
name: str = Field(description="The name of the trigger")
description: TriggerDescription = Field(description="The description of the trigger")
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderCredentialApiEntity"]

View File

@ -4,18 +4,11 @@ from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.parameters import PluginParameterAutoGenerate, PluginParameterOption, PluginParameterTemplate
from core.tools.entities.common_entities import I18nObject
class TriggerParameterOption(BaseModel):
"""
The option of the trigger parameter
"""
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
class TriggerParameterType(StrEnum):
"""The type of the parameter"""
@ -32,20 +25,6 @@ class TriggerParameterType(StrEnum):
DYNAMIC_SELECT = "dynamic-select"
class ParameterAutoGenerate(BaseModel):
"""Auto generation configuration for parameters"""
enabled: bool = Field(default=False, description="Whether auto generation is enabled")
template: Optional[str] = Field(default=None, description="Template for auto generation")
class ParameterTemplate(BaseModel):
"""Template configuration for parameters"""
value: str = Field(..., description="Template value")
type: str = Field(default="jinja2", description="Template type")
class TriggerParameter(BaseModel):
"""
The parameter of the trigger
@ -54,17 +33,17 @@ class TriggerParameter(BaseModel):
name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user")
type: TriggerParameterType = Field(..., description="The type of the parameter")
auto_generate: Optional[ParameterAutoGenerate] = Field(
auto_generate: Optional[PluginParameterAutoGenerate] = Field(
default=None, description="The auto generate of the parameter"
)
template: Optional[ParameterTemplate] = Field(default=None, description="The template of the parameter")
template: Optional[PluginParameterTemplate] = Field(default=None, description="The template of the parameter")
scope: Optional[str] = None
required: Optional[bool] = False
default: Union[int, float, str, None] = None
min: Union[float, int, None] = None
max: Union[float, int, None] = None
precision: Optional[int] = None
options: Optional[list[TriggerParameterOption]] = None
options: Optional[list[PluginParameterOption]] = None
description: Optional[I18nObject] = None
@ -89,7 +68,7 @@ class TriggerIdentity(BaseModel):
author: str = Field(..., description="The author of the trigger")
name: str = Field(..., description="The name of the trigger")
label: I18nObject = Field(..., description="The label of the trigger")
provider: str = Field(..., description="The provider of the trigger")
class TriggerDescription(BaseModel):
"""
@ -100,69 +79,23 @@ class TriggerDescription(BaseModel):
llm: I18nObject = Field(..., description="LLM readable description")
class TriggerConfigurationExtraPython(BaseModel):
"""Python configuration for trigger"""
source: str = Field(..., description="The source file path for the trigger implementation")
class TriggerConfigurationExtra(BaseModel):
"""
The extra configuration for trigger
"""
class TriggerEntity(BaseModel):
"""
The configuration of a trigger
"""
python: TriggerConfigurationExtraPython
identity: TriggerIdentity = Field(..., description="The identity of the trigger")
parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger")
description: TriggerDescription = Field(..., description="The description of the trigger")
extra: TriggerConfigurationExtra = Field(..., description="The extra configuration of the trigger")
output_schema: Optional[Mapping[str, Any]] = Field(
default=None, description="The output schema that this trigger produces"
)
class TriggerProviderConfigurationExtraPython(BaseModel):
"""Python configuration for trigger provider"""
source: str = Field(..., description="The source file path for the trigger provider implementation")
class TriggerProviderConfigurationExtra(BaseModel):
"""
The extra configuration for trigger provider
"""
python: TriggerProviderConfigurationExtraPython
class OAuthSchema(BaseModel):
"""OAuth configuration schema"""
authorization_url: str = Field(..., description="OAuth authorization URL")
token_url: str = Field(..., description="OAuth token URL")
client_id: str = Field(..., description="OAuth client ID")
client_secret: str = Field(..., description="OAuth client secret")
redirect_uri: Optional[str] = Field(default=None, description="OAuth redirect URI")
scope: Optional[str] = Field(default=None, description="OAuth scope")
class ProviderConfig(BaseModel):
"""Provider configuration item"""
name: str = Field(..., description="Configuration field name")
type: str = Field(..., description="Configuration field type")
required: bool = Field(default=False, description="Whether this field is required")
default: Any = Field(default=None, description="Default value")
label: Optional[I18nObject] = Field(default=None, description="Field label")
description: Optional[I18nObject] = Field(default=None, description="Field description")
options: Optional[list[dict[str, Any]]] = Field(default=None, description="Options for select type")
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class TriggerProviderEntity(BaseModel):
"""
@ -183,7 +116,6 @@ class TriggerProviderEntity(BaseModel):
description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters",
)
triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")
extra: TriggerProviderConfigurationExtra = Field(..., description="The extra configuration of the trigger provider")
class Subscription(BaseModel):
@ -223,21 +155,12 @@ class Unsubscription(BaseModel):
# Export all entities
__all__ = [
"OAuthSchema",
"ParameterAutoGenerate",
"ParameterTemplate",
"ProviderConfig",
"Subscription",
"TriggerConfigurationExtra",
"TriggerConfigurationExtraPython",
"TriggerDescription",
"TriggerEntity",
"TriggerEntity",
"TriggerIdentity",
"TriggerParameter",
"TriggerParameterOption",
"TriggerParameterType",
"TriggerProviderConfigurationExtra",
"TriggerProviderConfigurationExtraPython",
"TriggerProviderEntity",
"TriggerProviderIdentity",
"Unsubscription",

View File

@ -6,8 +6,11 @@ import logging
import time
from typing import Optional
from core.entities.provider_entities import BasicProviderConfig
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities import (
from core.trigger.entities.api_entities import TriggerProviderApiEntity
from core.trigger.entities.entities import (
ProviderConfig,
Subscription,
TriggerEntity,
@ -19,7 +22,7 @@ from core.trigger.entities import (
logger = logging.getLogger(__name__)
class TriggerProviderController:
class PluginTriggerProviderController:
"""
Controller for plugin trigger providers
"""
@ -44,6 +47,18 @@ class TriggerProviderController:
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
def get_provider_id(self) -> TriggerProviderID:
"""
Get provider ID
"""
return TriggerProviderID(f"{self.plugin_id}/{self.entity.identity.name}")
def to_api_entity(self) -> TriggerProviderApiEntity:
"""
Convert to API entity
"""
return TriggerProviderApiEntity(**self.entity.model_dump())
@property
def identity(self) -> TriggerProviderIdentity:
"""Get provider identity"""
@ -69,14 +84,6 @@ class TriggerProviderController:
return trigger
return None
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
Get credentials schema for this provider
:return: List of provider config schemas
"""
return self.entity.credentials_schema
def get_subscription_schema(self) -> list[ProviderConfig]:
"""
Get subscription schema for this provider
@ -109,18 +116,24 @@ class TriggerProviderController:
types.append(CredentialType.API_KEY)
return types
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
def get_credentials_schema(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
"""
Get credentials schema by credential type
:param credential_type: The type of credential (oauth or api_key)
:return: List of provider config schemas
"""
if credential_type == CredentialType.OAUTH2.value:
credential_type = CredentialType.of(credential_type) if isinstance(credential_type, str) else credential_type
if credential_type == CredentialType.OAUTH2:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
"""
Get credential schema config by credential type
"""
return [x.to_basic_provider_config() for x in self.get_credentials_schema(credential_type)]
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
@ -183,17 +196,5 @@ class TriggerProviderController:
logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id)
return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}")
def handle_webhook(self, webhook_path: str, request_data: dict, credentials: dict) -> dict:
"""
Handle incoming webhook through plugin runtime
:param webhook_path: Webhook path
:param request_data: Request data
:param credentials: Provider credentials
:return: Webhook handling result
"""
logger.info("Handling webhook for path %s for plugin %s", webhook_path, self.plugin_id)
return {"success": True, "path": webhook_path, "plugin": self.plugin_id, "data_received": request_data}
__all__ = ["TriggerProviderController"]
__all__ = ["PluginTriggerProviderController"]

View File

@ -5,11 +5,14 @@ Trigger Manager for loading and managing trigger providers and triggers
import logging
from typing import Optional
from core.trigger.entities import (
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.impl.trigger import PluginTriggerManager
from core.trigger.entities.entities import (
ProviderConfig,
Subscription,
TriggerEntity,
Unsubscription,
)
from core.trigger.plugin_trigger import PluginTriggerController
from core.trigger.provider import PluginTriggerProviderController
logger = logging.getLogger(__name__)
@ -28,7 +31,7 @@ class TriggerManager:
:param tenant_id: Tenant ID
:return: List of trigger provider controllers
"""
manager = PluginTriggerController()
manager = PluginTriggerManager()
provider_entities = manager.fetch_trigger_providers(tenant_id)
controllers = []
@ -48,22 +51,21 @@ class TriggerManager:
return controllers
@classmethod
def get_plugin_trigger_provider(
cls, tenant_id: str, plugin_id: str, provider_name: str
) -> Optional[PluginTriggerProviderController]:
def get_trigger_provider(
cls, tenant_id: str, provider_id: TriggerProviderID
) -> PluginTriggerProviderController:
"""
Get a specific plugin trigger provider
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param provider_id: Provider ID
:return: Trigger provider controller or None
"""
manager = PluginTriggerManager()
provider = manager.fetch_trigger_provider(tenant_id, plugin_id, provider_name)
provider = manager.fetch_trigger_provider(tenant_id, provider_id)
if not provider:
return None
raise ValueError(f"Trigger provider {provider_id} not found")
try:
return PluginTriggerProviderController(
@ -74,287 +76,139 @@ class TriggerManager:
)
except Exception as e:
logger.exception("Failed to load trigger provider")
return None
raise e
@classmethod
def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
"""
List all trigger providers (plugin and builtin)
List all trigger providers (plugin)
:param tenant_id: Tenant ID
:return: List of all trigger provider controllers
"""
providers = []
# Get plugin providers
plugin_providers = cls.list_plugin_trigger_providers(tenant_id)
providers.extend(plugin_providers)
# TODO: Add builtin providers when implemented
# builtin_providers = cls.list_builtin_trigger_providers(tenant_id)
# providers.extend(builtin_providers)
return providers
return cls.list_plugin_trigger_providers(tenant_id)
@classmethod
def list_triggers_by_provider(cls, tenant_id: str, plugin_id: str, provider_name: str) -> list[TriggerEntity]:
def list_triggers_by_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerEntity]:
"""
List all triggers for a specific provider
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param provider_id: Provider ID
:return: List of trigger entities
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return []
provider = cls.get_trigger_provider(tenant_id, provider_id)
return provider.get_triggers()
@classmethod
def get_trigger(
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str
cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str
) -> Optional[TriggerEntity]:
"""
Get a specific trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param provider_id: Provider ID
:param trigger_name: Trigger name
:return: Trigger entity or None
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return None
return provider.get_trigger(trigger_name)
return cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
@classmethod
def validate_trigger_credentials(
cls, tenant_id: str, plugin_id: str, provider_name: str, credentials: dict
cls, tenant_id: str, provider_id: TriggerProviderID, credentials: dict
) -> tuple[bool, str]:
"""
Validate trigger provider credentials
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param provider_id: Provider ID
:param credentials: Credentials to validate
:return: Tuple of (is_valid, error_message)
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return False, "Provider not found"
try:
provider.validate_credentials(credentials)
cls.get_trigger_provider(tenant_id, provider_id).validate_credentials(credentials)
return True, ""
except Exception as e:
return False, str(e)
@classmethod
def execute_trigger(
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str, parameters: dict, credentials: dict
cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str, parameters: dict, credentials: dict
) -> dict:
"""
Execute a trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param provider_id: Provider ID
:param trigger_name: Trigger name
:param parameters: Trigger parameters
:param credentials: Provider credentials
:return: Trigger execution result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
trigger = provider.get_trigger(trigger_name)
trigger = cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
if not trigger:
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_name}")
return provider.execute_trigger(trigger_name, parameters, credentials)
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}")
return cls.get_trigger_provider(tenant_id, provider_id).execute_trigger(trigger_name, parameters, credentials)
@classmethod
def subscribe_trigger(
cls,
tenant_id: str,
plugin_id: str,
provider_name: str,
provider_id: TriggerProviderID,
trigger_name: str,
subscription_params: dict,
credentials: dict,
) -> dict:
) -> Subscription:
"""
Subscribe to a trigger (e.g., register webhook)
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param provider_id: Provider ID
:param trigger_name: Trigger name
:param subscription_params: Subscription parameters
:param credentials: Provider credentials
:return: Subscription result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.subscribe_trigger(trigger_name, subscription_params, credentials)
return cls.get_trigger_provider(tenant_id, provider_id).subscribe_trigger(
trigger_name, subscription_params, credentials
)
@classmethod
def unsubscribe_trigger(
cls,
tenant_id: str,
plugin_id: str,
provider_name: str,
provider_id: TriggerProviderID,
trigger_name: str,
subscription_metadata: dict,
credentials: dict,
) -> dict:
) -> Unsubscription:
"""
Unsubscribe from a trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param provider_id: Provider ID
:param trigger_name: Trigger name
:param subscription_metadata: Subscription metadata from subscribe operation
:param credentials: Provider credentials
:return: Unsubscription result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.unsubscribe_trigger(trigger_name, subscription_metadata, credentials)
@classmethod
def handle_webhook(
cls,
tenant_id: str,
plugin_id: str,
provider_name: str,
webhook_path: str,
request_data: dict,
credentials: dict,
) -> dict:
"""
Handle incoming webhook for a trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param webhook_path: Webhook path
:param request_data: Webhook request data
:param credentials: Provider credentials
:return: Webhook handling result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.handle_webhook(webhook_path, request_data, credentials)
@classmethod
def get_provider_credentials_schema(
cls, tenant_id: str, plugin_id: str, provider_name: str
) -> list[ProviderConfig]:
"""
Get provider credentials schema
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: List of provider config schemas
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return []
return provider.get_credentials_schema()
return cls.get_trigger_provider(tenant_id, provider_id).unsubscribe_trigger(
trigger_name, subscription_metadata, credentials
)
@classmethod
def get_provider_subscription_schema(
cls, tenant_id: str, plugin_id: str, provider_name: str
cls, tenant_id: str, provider_id: TriggerProviderID
) -> list[ProviderConfig]:
"""
Get provider subscription schema
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param provider_id: Provider ID
:return: List of subscription config schemas
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return []
return provider.get_subscription_schema()
@classmethod
def get_provider_info(cls, tenant_id: str, plugin_id: str, provider_name: str) -> Optional[dict]:
"""
Get provider information
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: Provider info dict or None
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return None
return {
"plugin_id": plugin_id,
"provider_name": provider_name,
"identity": provider.entity.identity.model_dump() if provider.entity.identity else {},
"credentials_schema": [c.model_dump() for c in provider.entity.credentials_schema],
"subscription_schema": [s.model_dump() for s in provider.entity.subscription_schema],
"oauth_enabled": provider.entity.oauth_schema is not None,
"trigger_count": len(provider.entity.triggers),
"triggers": [t.identity.model_dump() for t in provider.entity.triggers],
}
@classmethod
def list_providers_for_workflow(cls, tenant_id: str) -> list[dict]:
"""
List trigger providers suitable for workflow usage
:param tenant_id: Tenant ID
:return: List of provider info dicts
"""
providers = cls.list_all_trigger_providers(tenant_id)
result = []
for provider in providers:
info = {
"plugin_id": provider.plugin_id,
"provider_name": provider.entity.identity.name,
"label": provider.entity.identity.label.model_dump(),
"description": provider.entity.identity.description.model_dump(),
"icon": provider.entity.identity.icon,
"trigger_count": len(provider.entity.triggers),
}
result.append(info)
return result
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
# Export
__all__ = ["TriggerManager"]

View File

@ -0,0 +1,49 @@
from core.helper.provider_cache import TriggerProviderCredentialCache, TriggerProviderOAuthClientCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity
from core.trigger.provider import PluginTriggerProviderController
from models.trigger import TriggerProvider
def create_trigger_provider_encrypter_for_credential(
tenant_id: str,
controller: PluginTriggerProviderController,
credential: TriggerProvider | TriggerProviderCredentialApiEntity,
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_credential_schema_config(credential.credential_type),
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential.id,
),
)
def create_trigger_provider_encrypter(
tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_credential_schema_config(credential_type),
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential_id,
),
)
def create_trigger_provider_oauth_encrypter(
tenant_id: str, controller: PluginTriggerProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
cache=TriggerProviderOAuthClientCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
),
)

View File

@ -1,5 +1,6 @@
import json
from datetime import UTC, datetime
import time
from datetime import datetime
from typing import cast
import sqlalchemy as sa
@ -7,6 +8,7 @@ from sqlalchemy import DateTime, Index, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity
from models.base import Base
from models.types import StringUUID
@ -45,20 +47,23 @@ class TriggerProvider(Base):
except (json.JSONDecodeError, TypeError):
return {}
@property
def credentials_str(self) -> str:
"""Get credentials as string"""
return self.encrypted_credentials or "{}"
def is_oauth_expired(self) -> bool:
"""Check if OAuth token is expired"""
if self.credential_type != CredentialType.OAUTH2.value:
return False
if self.expires_at == -1:
return False
# Check if token expires in next 60 seconds
return (self.expires_at - 60) < int(datetime.now(UTC).timestamp())
# Check if token expires in next 3 minutes
return (self.expires_at - 180) < int(time.time())
def to_api_entity(self) -> TriggerProviderCredentialApiEntity:
return TriggerProviderCredentialApiEntity(
id=self.id,
name=self.name,
provider=self.provider_id,
credential_type=CredentialType(self.credential_type),
credentials=self.credentials,
)
# system level trigger oauth client params
class TriggerOAuthSystemClient(Base):

View File

@ -13,14 +13,20 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.service import PluginService
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.tools.utils.encryption import (
create_provider_encrypter,
)
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
create_trigger_provider_encrypter_for_credential,
create_trigger_provider_oauth_encrypter,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider
from services.plugin.oauth_service import OAuthProxyService
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
@ -31,13 +37,34 @@ class TriggerProviderService:
__MAX_TRIGGER_PROVIDER_COUNT__ = 100
@classmethod
def list_trigger_providers(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerProvider]:
def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
"""List all trigger providers for the current tenant"""
# TODO fetch trigger plugin controller
return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]
# TODO fetch all trigger plugin credentials
@classmethod
def list_trigger_provider_credentials(
cls, tenant_id: str, provider_id: TriggerProviderID
) -> list[TriggerProviderCredentialApiEntity]:
"""List all trigger providers for the current tenant"""
credentials: list[TriggerProviderCredentialApiEntity] = []
with Session(db.engine, autoflush=False) as session:
return session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).all()
credentials_db = (
session.query(TriggerProvider)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
.order_by(desc(TriggerProvider.created_at))
.all()
)
credentials = [credential.to_api_entity() for credential in credentials_db]
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
for credential in credentials:
encrypter, _ = create_trigger_provider_encrypter_for_credential(
tenant_id=tenant_id,
controller=provider_controller,
credential=credential,
)
credential.credentials = encrypter.decrypt(credential.credentials)
return credentials
@classmethod
def add_trigger_provider(
@ -63,6 +90,7 @@ class TriggerProviderService:
:return: Success response
"""
try:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine) as session:
# Use distributed lock to prevent race conditions
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
@ -96,10 +124,9 @@ class TriggerProviderService:
if existing:
raise ValueError(f"Credential name '{name}' already exists for this provider")
# Create encrypter for credentials
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[], # We'll define schema later in TriggerProvider classes
config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(),
)
@ -141,20 +168,21 @@ class TriggerProviderService:
:return: Success response
"""
with Session(db.engine) as session:
# Get provider
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found")
try:
# Update credentials if provided
if credentials:
encrypter, cache = cls._create_provider_encrypter(
tenant_id=tenant_id,
provider=db_provider,
)
provider_controller = TriggerManager.get_trigger_provider(
tenant_id, TriggerProviderID(db_provider.provider_id)
)
if credentials:
encrypter, cache = create_trigger_provider_encrypter_for_credential(
tenant_id=tenant_id,
controller=provider_controller,
credential=db_provider,
)
# Handle hidden values
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials = {
@ -200,14 +228,20 @@ class TriggerProviderService:
if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found")
# Delete provider
provider_controller = TriggerManager.get_trigger_provider(
tenant_id, TriggerProviderID(db_provider.provider_id)
)
# Clear cache
_, cache = create_trigger_provider_encrypter_for_credential(
tenant_id=tenant_id,
controller=provider_controller,
credential=db_provider,
)
session.delete(db_provider)
session.commit()
# Clear cache
_, cache = cls._create_provider_encrypter(tenant_id, db_provider)
cache.delete()
return {"result": "success"}
@classmethod
@ -232,13 +266,13 @@ class TriggerProviderService:
if db_provider.credential_type != CredentialType.OAUTH2.value:
raise ValueError("Only OAuth credentials can be refreshed")
# Parse provider ID
provider_id = TriggerProviderID(db_provider.provider_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
# Create encrypter
encrypter, cache = cls._create_provider_encrypter(
encrypter, cache = create_trigger_provider_encrypter_for_credential(
tenant_id=tenant_id,
provider=db_provider,
controller=provider_controller,
credential=db_provider,
)
# Decrypt current credentials
@ -285,18 +319,8 @@ class TriggerProviderService:
:param provider_id: Provider identifier
:return: OAuth client configuration or None
"""
# Get trigger provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
# Create encrypter for OAuth client params
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine, autoflush=False) as session:
# First check for tenant-specific OAuth client
tenant_client: TriggerOAuthTenantClient | None = (
session.query(TriggerOAuthTenantClient)
.filter_by(
@ -310,10 +334,10 @@ class TriggerProviderService:
oauth_params: Mapping[str, Any] | None = None
if tenant_client:
encrypter, _ = create_trigger_provider_oauth_encrypter(tenant_id, provider_controller)
oauth_params = encrypter.decrypt(tenant_client.oauth_params)
return oauth_params
# Only verified plugins can use system OAuth client
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
if not is_verified:
return oauth_params
@ -354,7 +378,7 @@ class TriggerProviderService:
return {"result": "success"}
# Get provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine) as session:
# Find existing custom client params
@ -425,7 +449,7 @@ class TriggerProviderService:
return {}
# Get provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
# Create encrypter to decrypt and mask values
encrypter, _ = create_provider_encrypter(
@ -477,63 +501,6 @@ class TriggerProviderService:
)
return custom_client is not None
@classmethod
def create_oauth_proxy_context(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
) -> str:
"""
Create OAuth proxy context for authorization flow.
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider: Provider identifier
:return: Context ID for OAuth flow
"""
return OAuthProxyService.create_proxy_context(
user_id=user_id,
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
@classmethod
def _create_provider_encrypter(
cls, tenant_id: str, provider: TriggerProvider
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
"""
Create encrypter and cache for trigger provider credentials
:param tenant_id: Tenant ID
:param provider: TriggerProvider instance
:return: Tuple of encrypter and cache
"""
from core.helper.provider_cache import TriggerProviderCredentialCache
# Parse provider ID
provider_id = TriggerProviderID(provider.provider_id)
# Get trigger provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
# Create encrypter with appropriate schema based on credential type
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(provider.credential_type)
],
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider=provider.provider_id,
credential_id=provider.id,
),
)
return encrypter, cache
@classmethod
def _generate_provider_name(
cls,

View File

@ -1,5 +1,4 @@
import logging
from typing import Any
from flask import Request, Response
@ -21,3 +20,4 @@ class TriggerService:
# TODO dispatch by the trigger controller
# TODO using the dispatch result(events) to invoke the trigger events
raise NotImplementedError("Not implemented")