feat(trigger): implement complete OAuth authorization flow for trigger providers

- Add OAuth authorization URL generation API endpoint
- Implement OAuth callback handler for credential storage
- Support both system-level and tenant-level OAuth clients
- Add trigger provider credential encryption utilities
- Refactor trigger entities into separate modules
- Update trigger provider service with OAuth client management
- Add credential cache for trigger providers

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Harry 2025-08-28 15:20:15 +08:00
parent 87120ad4ac
commit a46c9238fa
13 changed files with 420 additions and 456 deletions

View File

@ -1,20 +1,37 @@
import logging import logging
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, Forbidden from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import TriggerProviderID from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.account import Account from models.account import Account
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_provider_service import TriggerProviderService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TriggerProviderListApi(Resource): class TriggerProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List all trigger providers for the current tenant"""
user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
class TriggerProviderCredentialListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -27,8 +44,10 @@ class TriggerProviderListApi(Resource):
raise Forbidden() raise Forbidden()
try: try:
return TriggerProviderService.list_trigger_providers( return jsonable_encoder(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider) TriggerProviderService.list_trigger_provider_credentials(
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
)
) )
except Exception as e: except Exception as e:
logger.exception("Error listing trigger providers", exc_info=e) logger.exception("Error listing trigger providers", exc_info=e)
@ -145,30 +164,128 @@ class TriggerProviderOAuthAuthorizeApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
"""Initiate OAuth authorization flow for a provider""" """Initiate OAuth authorization flow for a trigger provider"""
user = current_user user = current_user
assert isinstance(user, Account) assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
try: try:
context_id = TriggerProviderService.create_oauth_proxy_context( provider_id = TriggerProviderID(provider)
tenant_id=user.current_tenant_id, plugin_id = provider_id.plugin_id
user_id=user.id, provider_name = provider_id.provider_name
provider_id=TriggerProviderID(provider), tenant_id = user.current_tenant_id
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
) )
# TODO: Build OAuth authorization URL if oauth_client_params is None:
# This will be implemented when we have provider-specific OAuth configs raise Forbidden("No OAuth client configuration found for this trigger provider")
return { # Create OAuth handler and proxy context
"context_id": context_id, oauth_handler = OAuthHandler()
"authorization_url": f"/oauth/authorize?context={context_id}", context_id = OAuthProxyService.create_proxy_context(
} user_id=user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
)
# Build redirect URI for callback
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
# Get authorization URL
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
# Create response with cookie
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
except Exception as e: except Exception as e:
logger.exception("Error initiating OAuth flow", exc_info=e) logger.exception("Error initiating OAuth flow", exc_info=e)
raise raise
class TriggerProviderOAuthCallbackApi(Resource):
@setup_required
def get(self, provider):
"""Handle OAuth callback for trigger provider"""
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
# Use and validate proxy context
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
# Parse provider ID
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(
tenant_id=tenant_id,
provider_id=provider_id,
)
if oauth_client_params is None:
raise Forbidden("No OAuth client configuration found for this trigger provider")
# Get OAuth credentials from callback
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
)
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials:
raise Exception("Failed to get OAuth credentials")
# Save OAuth credentials to database
TriggerProviderService.add_trigger_provider(
tenant_id=tenant_id,
user_id=user_id,
provider_id=provider_id,
credential_type=CredentialType.OAUTH2,
credentials=dict(credentials),
expires_at=expires_at,
)
# Redirect to OAuth callback page
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class TriggerProviderOAuthRefreshTokenApi(Resource): class TriggerProviderOAuthRefreshTokenApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -257,16 +374,13 @@ class TriggerProviderOAuthClientManageApi(Resource):
try: try:
provider_id = TriggerProviderID(provider) provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
result = TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id, tenant_id=user.current_tenant_id,
provider_id=provider_id, provider_id=provider_id,
client_params=args.get("client_params"), client_params=args.get("client_params"),
enabled=args.get("enabled"), enabled=args.get("enabled"),
) )
return result
except ValueError as e: except ValueError as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
except Exception as e: except Exception as e:
@ -287,13 +401,10 @@ class TriggerProviderOAuthClientManageApi(Resource):
try: try:
provider_id = TriggerProviderID(provider) provider_id = TriggerProviderID(provider)
result = TriggerProviderService.delete_custom_oauth_client_params( return TriggerProviderService.delete_custom_oauth_client_params(
tenant_id=user.current_tenant_id, tenant_id=user.current_tenant_id,
provider_id=provider_id, provider_id=provider_id,
) )
return result
except ValueError as e: except ValueError as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
except Exception as e: except Exception as e:
@ -302,8 +413,12 @@ class TriggerProviderOAuthClientManageApi(Resource):
# Trigger provider endpoints # Trigger provider endpoints
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-provider/<path:provider>/list") api.add_resource(
api.add_resource(TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider/<path:provider>/add") TriggerProviderCredentialListApi, "/workspaces/current/trigger-provider/credentials/<path:provider>/list"
)
api.add_resource(
TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider/credentials/<path:provider>/add"
)
api.add_resource( api.add_resource(
TriggerProviderCredentialsUpdateApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/update" TriggerProviderCredentialsUpdateApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/update"
) )
@ -311,9 +426,11 @@ api.add_resource(
TriggerProviderCredentialsDeleteApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/delete" TriggerProviderCredentialsDeleteApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/delete"
) )
# OAuth
api.add_resource( api.add_resource(
TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize" TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
) )
api.add_resource(TriggerProviderOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource( api.add_resource(
TriggerProviderOAuthRefreshTokenApi, TriggerProviderOAuthRefreshTokenApi,
"/workspaces/current/trigger-provider/credentials/<path:credential_id>/oauth/refresh", "/workspaces/current/trigger-provider/credentials/<path:credential_id>/oauth/refresh",

View File

@ -71,15 +71,25 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
class TriggerProviderCredentialCache(ProviderCredentialsCache): class TriggerProviderCredentialCache(ProviderCredentialsCache):
"""Cache for trigger provider credentials""" """Cache for trigger provider credentials"""
def __init__(self, tenant_id: str, provider: str, credential_id: str): def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
def _generate_cache_key(self, **kwargs) -> str: def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"] tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"] provider_id = kwargs["provider_id"]
credential_id = kwargs["credential_id"] credential_id = kwargs["credential_id"]
return f"trigger_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
class TriggerProviderOAuthClientCache(ProviderCredentialsCache):
"""Cache for trigger provider OAuth client"""
def __init__(self, tenant_id: str, provider_id: str):
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_id = kwargs["provider_id"]
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
class NoOpProviderCredentialCache: class NoOpProviderCredentialCache:
"""No-op provider credential cache""" """No-op provider credential cache"""

View File

@ -14,7 +14,7 @@ from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
from core.trigger.entities import TriggerProviderEntity from core.trigger.entities.entities import TriggerProviderEntity
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))

View File

@ -1,7 +1,7 @@
from typing import Any from typing import Any
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import PluginToolProviderEntity, PluginTriggerProviderEntity from core.plugin.entities.plugin_daemon import PluginTriggerProviderEntity
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
@ -15,15 +15,15 @@ class PluginTriggerManager(BasePluginClient):
for provider in json_response.get("data", []): for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {} declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name") provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []): for trigger in declaration.get("triggers", []):
tool["identity"]["provider"] = provider_name trigger["identity"]["provider"] = provider_name
return json_response return json_response
response = self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response(
"GET", "GET",
f"plugin/{tenant_id}/management/tools", f"plugin/{tenant_id}/management/triggers",
list[PluginToolProviderEntity], list[PluginTriggerProviderEntity],
params={"page": 1, "page_size": 256}, params={"page": 1, "page_size": 256},
transformer=transformer, transformer=transformer,
) )
@ -32,37 +32,36 @@ class PluginTriggerManager(BasePluginClient):
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name # override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.tools: for trigger in provider.declaration.triggers:
tool.identity.provider = provider.declaration.identity.name trigger.identity.provider = provider.declaration.identity.name
return response return response
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
""" """
Fetch tool provider for the given tenant and plugin. Fetch tool provider for the given tenant and plugin.
""" """
tool_provider_id = ToolProviderID(provider)
def transformer(json_response: dict[str, Any]) -> dict: def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data") data = json_response.get("data")
if data: if data:
for tool in data.get("declaration", {}).get("tools", []): for trigger in data.get("declaration", {}).get("triggers", []):
tool["identity"]["provider"] = tool_provider_id.provider_name trigger["identity"]["provider"] = provider_id.provider_name
return json_response return json_response
response = self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response(
"GET", "GET",
f"plugin/{tenant_id}/management/tool", f"plugin/{tenant_id}/management/trigger",
PluginToolProviderEntity, PluginTriggerProviderEntity,
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}, params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
transformer=transformer, transformer=transformer,
) )
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name # override the provider name for each trigger to plugin_id/provider_name
for tool in response.declaration.tools: for trigger in response.declaration.triggers:
tool.identity.provider = response.declaration.identity.name trigger.identity.provider = response.declaration.identity.name
return response return response

View File

@ -122,7 +122,6 @@ class ProviderConfigEncrypter:
self.provider_config_cache.set(data) self.provider_config_cache.set(data)
return data return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache

View File

@ -0,0 +1,40 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import (
OAuthSchema,
TriggerDescription,
TriggerEntity,
TriggerParameter,
TriggerProviderIdentity,
)
class TriggerProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider id of the credential")
credential_type: CredentialType = Field(description="The type of the credential")
credentials: dict = Field(description="The credentials of the credential")
class TriggerProviderApiEntity(BaseModel):
identity: TriggerProviderIdentity = Field(description="The identity of the trigger provider")
credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
oauth_schema: Optional[OAuthSchema] = Field(description="The OAuth schema of the trigger provider")
subscription_schema: list[ProviderConfig] = Field(description="The subscription schema of the trigger provider")
triggers: list[TriggerEntity] = Field(description="The triggers of the trigger provider")
class TriggerApiEntity(BaseModel):
name: str = Field(description="The name of the trigger")
description: TriggerDescription = Field(description="The description of the trigger")
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderCredentialApiEntity"]

View File

@ -4,18 +4,11 @@ from typing import Any, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.parameters import PluginParameterAutoGenerate, PluginParameterOption, PluginParameterTemplate
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
class TriggerParameterOption(BaseModel):
"""
The option of the trigger parameter
"""
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
class TriggerParameterType(StrEnum): class TriggerParameterType(StrEnum):
"""The type of the parameter""" """The type of the parameter"""
@ -32,20 +25,6 @@ class TriggerParameterType(StrEnum):
DYNAMIC_SELECT = "dynamic-select" DYNAMIC_SELECT = "dynamic-select"
class ParameterAutoGenerate(BaseModel):
"""Auto generation configuration for parameters"""
enabled: bool = Field(default=False, description="Whether auto generation is enabled")
template: Optional[str] = Field(default=None, description="Template for auto generation")
class ParameterTemplate(BaseModel):
"""Template configuration for parameters"""
value: str = Field(..., description="Template value")
type: str = Field(default="jinja2", description="Template type")
class TriggerParameter(BaseModel): class TriggerParameter(BaseModel):
""" """
The parameter of the trigger The parameter of the trigger
@ -54,17 +33,17 @@ class TriggerParameter(BaseModel):
name: str = Field(..., description="The name of the parameter") name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(..., description="The label presented to the user") label: I18nObject = Field(..., description="The label presented to the user")
type: TriggerParameterType = Field(..., description="The type of the parameter") type: TriggerParameterType = Field(..., description="The type of the parameter")
auto_generate: Optional[ParameterAutoGenerate] = Field( auto_generate: Optional[PluginParameterAutoGenerate] = Field(
default=None, description="The auto generate of the parameter" default=None, description="The auto generate of the parameter"
) )
template: Optional[ParameterTemplate] = Field(default=None, description="The template of the parameter") template: Optional[PluginParameterTemplate] = Field(default=None, description="The template of the parameter")
scope: Optional[str] = None scope: Optional[str] = None
required: Optional[bool] = False required: Optional[bool] = False
default: Union[int, float, str, None] = None default: Union[int, float, str, None] = None
min: Union[float, int, None] = None min: Union[float, int, None] = None
max: Union[float, int, None] = None max: Union[float, int, None] = None
precision: Optional[int] = None precision: Optional[int] = None
options: Optional[list[TriggerParameterOption]] = None options: Optional[list[PluginParameterOption]] = None
description: Optional[I18nObject] = None description: Optional[I18nObject] = None
@ -89,7 +68,7 @@ class TriggerIdentity(BaseModel):
author: str = Field(..., description="The author of the trigger") author: str = Field(..., description="The author of the trigger")
name: str = Field(..., description="The name of the trigger") name: str = Field(..., description="The name of the trigger")
label: I18nObject = Field(..., description="The label of the trigger") label: I18nObject = Field(..., description="The label of the trigger")
provider: str = Field(..., description="The provider of the trigger")
class TriggerDescription(BaseModel): class TriggerDescription(BaseModel):
""" """
@ -100,69 +79,23 @@ class TriggerDescription(BaseModel):
llm: I18nObject = Field(..., description="LLM readable description") llm: I18nObject = Field(..., description="LLM readable description")
class TriggerConfigurationExtraPython(BaseModel):
"""Python configuration for trigger"""
source: str = Field(..., description="The source file path for the trigger implementation")
class TriggerConfigurationExtra(BaseModel):
"""
The extra configuration for trigger
"""
class TriggerEntity(BaseModel): class TriggerEntity(BaseModel):
""" """
The configuration of a trigger The configuration of a trigger
""" """
python: TriggerConfigurationExtraPython
identity: TriggerIdentity = Field(..., description="The identity of the trigger") identity: TriggerIdentity = Field(..., description="The identity of the trigger")
parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger") parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger")
description: TriggerDescription = Field(..., description="The description of the trigger") description: TriggerDescription = Field(..., description="The description of the trigger")
extra: TriggerConfigurationExtra = Field(..., description="The extra configuration of the trigger")
output_schema: Optional[Mapping[str, Any]] = Field( output_schema: Optional[Mapping[str, Any]] = Field(
default=None, description="The output schema that this trigger produces" default=None, description="The output schema that this trigger produces"
) )
class TriggerProviderConfigurationExtraPython(BaseModel):
"""Python configuration for trigger provider"""
source: str = Field(..., description="The source file path for the trigger provider implementation")
class TriggerProviderConfigurationExtra(BaseModel):
"""
The extra configuration for trigger provider
"""
python: TriggerProviderConfigurationExtraPython
class OAuthSchema(BaseModel): class OAuthSchema(BaseModel):
"""OAuth configuration schema""" client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
authorization_url: str = Field(..., description="OAuth authorization URL") default_factory=list, description="The schema of the OAuth credentials"
token_url: str = Field(..., description="OAuth token URL") )
client_id: str = Field(..., description="OAuth client ID")
client_secret: str = Field(..., description="OAuth client secret")
redirect_uri: Optional[str] = Field(default=None, description="OAuth redirect URI")
scope: Optional[str] = Field(default=None, description="OAuth scope")
class ProviderConfig(BaseModel):
"""Provider configuration item"""
name: str = Field(..., description="Configuration field name")
type: str = Field(..., description="Configuration field type")
required: bool = Field(default=False, description="Whether this field is required")
default: Any = Field(default=None, description="Default value")
label: Optional[I18nObject] = Field(default=None, description="Field label")
description: Optional[I18nObject] = Field(default=None, description="Field description")
options: Optional[list[dict[str, Any]]] = Field(default=None, description="Options for select type")
class TriggerProviderEntity(BaseModel): class TriggerProviderEntity(BaseModel):
""" """
@ -183,7 +116,6 @@ class TriggerProviderEntity(BaseModel):
description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters", description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters",
) )
triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider") triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")
extra: TriggerProviderConfigurationExtra = Field(..., description="The extra configuration of the trigger provider")
class Subscription(BaseModel): class Subscription(BaseModel):
@ -223,21 +155,12 @@ class Unsubscription(BaseModel):
# Export all entities # Export all entities
__all__ = [ __all__ = [
"OAuthSchema", "OAuthSchema",
"ParameterAutoGenerate",
"ParameterTemplate",
"ProviderConfig",
"Subscription", "Subscription",
"TriggerConfigurationExtra",
"TriggerConfigurationExtraPython",
"TriggerDescription", "TriggerDescription",
"TriggerEntity", "TriggerEntity",
"TriggerEntity",
"TriggerIdentity", "TriggerIdentity",
"TriggerParameter", "TriggerParameter",
"TriggerParameterOption",
"TriggerParameterType", "TriggerParameterType",
"TriggerProviderConfigurationExtra",
"TriggerProviderConfigurationExtraPython",
"TriggerProviderEntity", "TriggerProviderEntity",
"TriggerProviderIdentity", "TriggerProviderIdentity",
"Unsubscription", "Unsubscription",

View File

@ -6,8 +6,11 @@ import logging
import time import time
from typing import Optional from typing import Optional
from core.entities.provider_entities import BasicProviderConfig
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities import ( from core.trigger.entities.api_entities import TriggerProviderApiEntity
from core.trigger.entities.entities import (
ProviderConfig, ProviderConfig,
Subscription, Subscription,
TriggerEntity, TriggerEntity,
@ -19,7 +22,7 @@ from core.trigger.entities import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TriggerProviderController: class PluginTriggerProviderController:
""" """
Controller for plugin trigger providers Controller for plugin trigger providers
""" """
@ -44,6 +47,18 @@ class TriggerProviderController:
self.plugin_id = plugin_id self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier self.plugin_unique_identifier = plugin_unique_identifier
def get_provider_id(self) -> TriggerProviderID:
"""
Get provider ID
"""
return TriggerProviderID(f"{self.plugin_id}/{self.entity.identity.name}")
def to_api_entity(self) -> TriggerProviderApiEntity:
"""
Convert to API entity
"""
return TriggerProviderApiEntity(**self.entity.model_dump())
@property @property
def identity(self) -> TriggerProviderIdentity: def identity(self) -> TriggerProviderIdentity:
"""Get provider identity""" """Get provider identity"""
@ -69,14 +84,6 @@ class TriggerProviderController:
return trigger return trigger
return None return None
def get_credentials_schema(self) -> list[ProviderConfig]:
"""
Get credentials schema for this provider
:return: List of provider config schemas
"""
return self.entity.credentials_schema
def get_subscription_schema(self) -> list[ProviderConfig]: def get_subscription_schema(self) -> list[ProviderConfig]:
""" """
Get subscription schema for this provider Get subscription schema for this provider
@ -109,18 +116,24 @@ class TriggerProviderController:
types.append(CredentialType.API_KEY) types.append(CredentialType.API_KEY)
return types return types
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: def get_credentials_schema(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
""" """
Get credentials schema by credential type Get credentials schema by credential type
:param credential_type: The type of credential (oauth or api_key) :param credential_type: The type of credential (oauth or api_key)
:return: List of provider config schemas :return: List of provider config schemas
""" """
if credential_type == CredentialType.OAUTH2.value: credential_type = CredentialType.of(credential_type) if isinstance(credential_type, str) else credential_type
if credential_type == CredentialType.OAUTH2:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value: if credential_type == CredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
"""
Get credential schema config by credential type
"""
return [x.to_basic_provider_config() for x in self.get_credentials_schema(credential_type)]
def get_oauth_client_schema(self) -> list[ProviderConfig]: def get_oauth_client_schema(self) -> list[ProviderConfig]:
""" """
@ -183,17 +196,5 @@ class TriggerProviderController:
logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id) logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id)
return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}") return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}")
def handle_webhook(self, webhook_path: str, request_data: dict, credentials: dict) -> dict:
"""
Handle incoming webhook through plugin runtime
:param webhook_path: Webhook path __all__ = ["PluginTriggerProviderController"]
:param request_data: Request data
:param credentials: Provider credentials
:return: Webhook handling result
"""
logger.info("Handling webhook for path %s for plugin %s", webhook_path, self.plugin_id)
return {"success": True, "path": webhook_path, "plugin": self.plugin_id, "data_received": request_data}
__all__ = ["TriggerProviderController"]

View File

@ -5,11 +5,14 @@ Trigger Manager for loading and managing trigger providers and triggers
import logging import logging
from typing import Optional from typing import Optional
from core.trigger.entities import ( from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.impl.trigger import PluginTriggerManager
from core.trigger.entities.entities import (
ProviderConfig, ProviderConfig,
Subscription,
TriggerEntity, TriggerEntity,
Unsubscription,
) )
from core.trigger.plugin_trigger import PluginTriggerController
from core.trigger.provider import PluginTriggerProviderController from core.trigger.provider import PluginTriggerProviderController
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,7 +31,7 @@ class TriggerManager:
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:return: List of trigger provider controllers :return: List of trigger provider controllers
""" """
manager = PluginTriggerController() manager = PluginTriggerManager()
provider_entities = manager.fetch_trigger_providers(tenant_id) provider_entities = manager.fetch_trigger_providers(tenant_id)
controllers = [] controllers = []
@ -48,22 +51,21 @@ class TriggerManager:
return controllers return controllers
@classmethod @classmethod
def get_plugin_trigger_provider( def get_trigger_provider(
cls, tenant_id: str, plugin_id: str, provider_name: str cls, tenant_id: str, provider_id: TriggerProviderID
) -> Optional[PluginTriggerProviderController]: ) -> PluginTriggerProviderController:
""" """
Get a specific plugin trigger provider Get a specific plugin trigger provider
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:param plugin_id: Plugin ID :param provider_id: Provider ID
:param provider_name: Provider name
:return: Trigger provider controller or None :return: Trigger provider controller or None
""" """
manager = PluginTriggerManager() manager = PluginTriggerManager()
provider = manager.fetch_trigger_provider(tenant_id, plugin_id, provider_name) provider = manager.fetch_trigger_provider(tenant_id, provider_id)
if not provider: if not provider:
return None raise ValueError(f"Trigger provider {provider_id} not found")
try: try:
return PluginTriggerProviderController( return PluginTriggerProviderController(
@ -74,287 +76,139 @@ class TriggerManager:
) )
except Exception as e: except Exception as e:
logger.exception("Failed to load trigger provider") logger.exception("Failed to load trigger provider")
return None raise e
@classmethod @classmethod
def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]: def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
""" """
List all trigger providers (plugin and builtin) List all trigger providers (plugin)
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:return: List of all trigger provider controllers :return: List of all trigger provider controllers
""" """
providers = [] return cls.list_plugin_trigger_providers(tenant_id)
# Get plugin providers
plugin_providers = cls.list_plugin_trigger_providers(tenant_id)
providers.extend(plugin_providers)
# TODO: Add builtin providers when implemented
# builtin_providers = cls.list_builtin_trigger_providers(tenant_id)
# providers.extend(builtin_providers)
return providers
@classmethod @classmethod
def list_triggers_by_provider(cls, tenant_id: str, plugin_id: str, provider_name: str) -> list[TriggerEntity]: def list_triggers_by_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerEntity]:
""" """
List all triggers for a specific provider List all triggers for a specific provider
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:param plugin_id: Plugin ID :param provider_id: Provider ID
:param provider_name: Provider name
:return: List of trigger entities :return: List of trigger entities
""" """
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) provider = cls.get_trigger_provider(tenant_id, provider_id)
if not provider:
return []
return provider.get_triggers() return provider.get_triggers()
@classmethod @classmethod
def get_trigger( def get_trigger(
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str
) -> Optional[TriggerEntity]: ) -> Optional[TriggerEntity]:
""" """
Get a specific trigger Get a specific trigger
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:param plugin_id: Plugin ID :param provider_id: Provider ID
:param provider_name: Provider name
:param trigger_name: Trigger name :param trigger_name: Trigger name
:return: Trigger entity or None :return: Trigger entity or None
""" """
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) return cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
if not provider:
return None
return provider.get_trigger(trigger_name)
@classmethod @classmethod
def validate_trigger_credentials( def validate_trigger_credentials(
cls, tenant_id: str, plugin_id: str, provider_name: str, credentials: dict cls, tenant_id: str, provider_id: TriggerProviderID, credentials: dict
) -> tuple[bool, str]: ) -> tuple[bool, str]:
""" """
Validate trigger provider credentials Validate trigger provider credentials
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:param plugin_id: Plugin ID :param provider_id: Provider ID
:param provider_name: Provider name
:param credentials: Credentials to validate :param credentials: Credentials to validate
:return: Tuple of (is_valid, error_message) :return: Tuple of (is_valid, error_message)
""" """
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return False, "Provider not found"
try: try:
provider.validate_credentials(credentials) cls.get_trigger_provider(tenant_id, provider_id).validate_credentials(credentials)
return True, "" return True, ""
except Exception as e: except Exception as e:
return False, str(e) return False, str(e)
@classmethod @classmethod
def execute_trigger( def execute_trigger(
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str, parameters: dict, credentials: dict cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str, parameters: dict, credentials: dict
) -> dict: ) -> dict:
""" """
Execute a trigger Execute a trigger
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:param plugin_id: Plugin ID :param provider_id: Provider ID
:param provider_name: Provider name
:param trigger_name: Trigger name :param trigger_name: Trigger name
:param parameters: Trigger parameters :param parameters: Trigger parameters
:param credentials: Provider credentials :param credentials: Provider credentials
:return: Trigger execution result :return: Trigger execution result
""" """
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) trigger = cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
trigger = provider.get_trigger(trigger_name)
if not trigger: if not trigger:
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_name}") raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}")
return cls.get_trigger_provider(tenant_id, provider_id).execute_trigger(trigger_name, parameters, credentials)
return provider.execute_trigger(trigger_name, parameters, credentials)
@classmethod @classmethod
def subscribe_trigger( def subscribe_trigger(
cls, cls,
tenant_id: str, tenant_id: str,
plugin_id: str, provider_id: TriggerProviderID,
provider_name: str,
trigger_name: str, trigger_name: str,
subscription_params: dict, subscription_params: dict,
credentials: dict, credentials: dict,
) -> dict: ) -> Subscription:
""" """
Subscribe to a trigger (e.g., register webhook) Subscribe to a trigger (e.g., register webhook)
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:param plugin_id: Plugin ID :param provider_id: Provider ID
:param provider_name: Provider name
:param trigger_name: Trigger name :param trigger_name: Trigger name
:param subscription_params: Subscription parameters :param subscription_params: Subscription parameters
:param credentials: Provider credentials :param credentials: Provider credentials
:return: Subscription result :return: Subscription result
""" """
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) return cls.get_trigger_provider(tenant_id, provider_id).subscribe_trigger(
trigger_name, subscription_params, credentials
if not provider: )
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.subscribe_trigger(trigger_name, subscription_params, credentials)
@classmethod @classmethod
def unsubscribe_trigger( def unsubscribe_trigger(
cls, cls,
tenant_id: str, tenant_id: str,
plugin_id: str, provider_id: TriggerProviderID,
provider_name: str,
trigger_name: str, trigger_name: str,
subscription_metadata: dict, subscription_metadata: dict,
credentials: dict, credentials: dict,
) -> dict: ) -> Unsubscription:
""" """
Unsubscribe from a trigger Unsubscribe from a trigger
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:param plugin_id: Plugin ID :param provider_id: Provider ID
:param provider_name: Provider name
:param trigger_name: Trigger name :param trigger_name: Trigger name
:param subscription_metadata: Subscription metadata from subscribe operation :param subscription_metadata: Subscription metadata from subscribe operation
:param credentials: Provider credentials :param credentials: Provider credentials
:return: Unsubscription result :return: Unsubscription result
""" """
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) return cls.get_trigger_provider(tenant_id, provider_id).unsubscribe_trigger(
trigger_name, subscription_metadata, credentials
if not provider: )
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.unsubscribe_trigger(trigger_name, subscription_metadata, credentials)
@classmethod
def handle_webhook(
cls,
tenant_id: str,
plugin_id: str,
provider_name: str,
webhook_path: str,
request_data: dict,
credentials: dict,
) -> dict:
"""
Handle incoming webhook for a trigger
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:param webhook_path: Webhook path
:param request_data: Webhook request data
:param credentials: Provider credentials
:return: Webhook handling result
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
return provider.handle_webhook(webhook_path, request_data, credentials)
@classmethod
def get_provider_credentials_schema(
cls, tenant_id: str, plugin_id: str, provider_name: str
) -> list[ProviderConfig]:
"""
Get provider credentials schema
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: List of provider config schemas
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return []
return provider.get_credentials_schema()
@classmethod @classmethod
def get_provider_subscription_schema( def get_provider_subscription_schema(
cls, tenant_id: str, plugin_id: str, provider_name: str cls, tenant_id: str, provider_id: TriggerProviderID
) -> list[ProviderConfig]: ) -> list[ProviderConfig]:
""" """
Get provider subscription schema Get provider subscription schema
:param tenant_id: Tenant ID :param tenant_id: Tenant ID
:param plugin_id: Plugin ID :param provider_id: Provider ID
:param provider_name: Provider name
:return: List of subscription config schemas :return: List of subscription config schemas
""" """
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
if not provider:
return []
return provider.get_subscription_schema()
@classmethod
def get_provider_info(cls, tenant_id: str, plugin_id: str, provider_name: str) -> Optional[dict]:
"""
Get provider information
:param tenant_id: Tenant ID
:param plugin_id: Plugin ID
:param provider_name: Provider name
:return: Provider info dict or None
"""
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
if not provider:
return None
return {
"plugin_id": plugin_id,
"provider_name": provider_name,
"identity": provider.entity.identity.model_dump() if provider.entity.identity else {},
"credentials_schema": [c.model_dump() for c in provider.entity.credentials_schema],
"subscription_schema": [s.model_dump() for s in provider.entity.subscription_schema],
"oauth_enabled": provider.entity.oauth_schema is not None,
"trigger_count": len(provider.entity.triggers),
"triggers": [t.identity.model_dump() for t in provider.entity.triggers],
}
@classmethod
def list_providers_for_workflow(cls, tenant_id: str) -> list[dict]:
"""
List trigger providers suitable for workflow usage
:param tenant_id: Tenant ID
:return: List of provider info dicts
"""
providers = cls.list_all_trigger_providers(tenant_id)
result = []
for provider in providers:
info = {
"plugin_id": provider.plugin_id,
"provider_name": provider.entity.identity.name,
"label": provider.entity.identity.label.model_dump(),
"description": provider.entity.identity.description.model_dump(),
"icon": provider.entity.identity.icon,
"trigger_count": len(provider.entity.triggers),
}
result.append(info)
return result
# Export # Export
__all__ = ["TriggerManager"] __all__ = ["TriggerManager"]

View File

@ -0,0 +1,49 @@
from core.helper.provider_cache import TriggerProviderCredentialCache, TriggerProviderOAuthClientCache
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity
from core.trigger.provider import PluginTriggerProviderController
from models.trigger import TriggerProvider
def create_trigger_provider_encrypter_for_credential(
tenant_id: str,
controller: PluginTriggerProviderController,
credential: TriggerProvider | TriggerProviderCredentialApiEntity,
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_credential_schema_config(credential.credential_type),
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential.id,
),
)
def create_trigger_provider_encrypter(
tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_credential_schema_config(credential_type),
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential_id,
),
)
def create_trigger_provider_oauth_encrypter(
tenant_id: str, controller: PluginTriggerProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
cache=TriggerProviderOAuthClientCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
),
)

View File

@ -1,5 +1,6 @@
import json import json
from datetime import UTC, datetime import time
from datetime import datetime
from typing import cast from typing import cast
import sqlalchemy as sa import sqlalchemy as sa
@ -7,6 +8,7 @@ from sqlalchemy import DateTime, Index, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity
from models.base import Base from models.base import Base
from models.types import StringUUID from models.types import StringUUID
@ -45,20 +47,23 @@ class TriggerProvider(Base):
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
return {} return {}
@property
def credentials_str(self) -> str:
"""Get credentials as string"""
return self.encrypted_credentials or "{}"
def is_oauth_expired(self) -> bool: def is_oauth_expired(self) -> bool:
"""Check if OAuth token is expired""" """Check if OAuth token is expired"""
if self.credential_type != CredentialType.OAUTH2.value: if self.credential_type != CredentialType.OAUTH2.value:
return False return False
if self.expires_at == -1: if self.expires_at == -1:
return False return False
# Check if token expires in next 60 seconds # Check if token expires in next 3 minutes
return (self.expires_at - 60) < int(datetime.now(UTC).timestamp()) return (self.expires_at - 180) < int(time.time())
def to_api_entity(self) -> TriggerProviderCredentialApiEntity:
return TriggerProviderCredentialApiEntity(
id=self.id,
name=self.name,
provider=self.provider_id,
credential_type=CredentialType(self.credential_type),
credentials=self.credentials,
)
# system level trigger oauth client params # system level trigger oauth client params
class TriggerOAuthSystemClient(Base): class TriggerOAuthSystemClient(Base):

View File

@ -13,14 +13,20 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.plugin.entities.plugin import TriggerProviderID from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from core.plugin.service import PluginService from core.tools.utils.encryption import (
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter create_provider_encrypter,
)
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity
from core.trigger.trigger_manager import TriggerManager from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import (
create_trigger_provider_encrypter_for_credential,
create_trigger_provider_oauth_encrypter,
)
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider
from services.plugin.oauth_service import OAuthProxyService from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,13 +37,34 @@ class TriggerProviderService:
__MAX_TRIGGER_PROVIDER_COUNT__ = 100 __MAX_TRIGGER_PROVIDER_COUNT__ = 100
@classmethod @classmethod
def list_trigger_providers(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerProvider]: def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
"""List all trigger providers for the current tenant""" """List all trigger providers for the current tenant"""
# TODO fetch trigger plugin controller return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]
# TODO fetch all trigger plugin credentials @classmethod
def list_trigger_provider_credentials(
cls, tenant_id: str, provider_id: TriggerProviderID
) -> list[TriggerProviderCredentialApiEntity]:
"""List all trigger providers for the current tenant"""
credentials: list[TriggerProviderCredentialApiEntity] = []
with Session(db.engine, autoflush=False) as session: with Session(db.engine, autoflush=False) as session:
return session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).all() credentials_db = (
session.query(TriggerProvider)
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
.order_by(desc(TriggerProvider.created_at))
.all()
)
credentials = [credential.to_api_entity() for credential in credentials_db]
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
for credential in credentials:
encrypter, _ = create_trigger_provider_encrypter_for_credential(
tenant_id=tenant_id,
controller=provider_controller,
credential=credential,
)
credential.credentials = encrypter.decrypt(credential.credentials)
return credentials
@classmethod @classmethod
def add_trigger_provider( def add_trigger_provider(
@ -63,6 +90,7 @@ class TriggerProviderService:
:return: Success response :return: Success response
""" """
try: try:
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine) as session: with Session(db.engine) as session:
# Use distributed lock to prevent race conditions # Use distributed lock to prevent race conditions
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}" lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
@ -96,10 +124,9 @@ class TriggerProviderService:
if existing: if existing:
raise ValueError(f"Credential name '{name}' already exists for this provider") raise ValueError(f"Credential name '{name}' already exists for this provider")
# Create encrypter for credentials
encrypter, _ = create_provider_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[], # We'll define schema later in TriggerProvider classes config=provider_controller.get_credential_schema_config(credential_type),
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
) )
@ -141,20 +168,21 @@ class TriggerProviderService:
:return: Success response :return: Success response
""" """
with Session(db.engine) as session: with Session(db.engine) as session:
# Get provider
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
if not db_provider: if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found") raise ValueError(f"Trigger provider credential {credential_id} not found")
try: try:
# Update credentials if provided provider_controller = TriggerManager.get_trigger_provider(
if credentials: tenant_id, TriggerProviderID(db_provider.provider_id)
encrypter, cache = cls._create_provider_encrypter( )
tenant_id=tenant_id,
provider=db_provider,
)
if credentials:
encrypter, cache = create_trigger_provider_encrypter_for_credential(
tenant_id=tenant_id,
controller=provider_controller,
credential=db_provider,
)
# Handle hidden values # Handle hidden values
original_credentials = encrypter.decrypt(db_provider.credentials) original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials = { new_credentials = {
@ -200,14 +228,20 @@ class TriggerProviderService:
if not db_provider: if not db_provider:
raise ValueError(f"Trigger provider credential {credential_id} not found") raise ValueError(f"Trigger provider credential {credential_id} not found")
# Delete provider provider_controller = TriggerManager.get_trigger_provider(
tenant_id, TriggerProviderID(db_provider.provider_id)
)
# Clear cache
_, cache = create_trigger_provider_encrypter_for_credential(
tenant_id=tenant_id,
controller=provider_controller,
credential=db_provider,
)
session.delete(db_provider) session.delete(db_provider)
session.commit() session.commit()
# Clear cache
_, cache = cls._create_provider_encrypter(tenant_id, db_provider)
cache.delete() cache.delete()
return {"result": "success"} return {"result": "success"}
@classmethod @classmethod
@ -232,13 +266,13 @@ class TriggerProviderService:
if db_provider.credential_type != CredentialType.OAUTH2.value: if db_provider.credential_type != CredentialType.OAUTH2.value:
raise ValueError("Only OAuth credentials can be refreshed") raise ValueError("Only OAuth credentials can be refreshed")
# Parse provider ID
provider_id = TriggerProviderID(db_provider.provider_id) provider_id = TriggerProviderID(db_provider.provider_id)
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
# Create encrypter # Create encrypter
encrypter, cache = cls._create_provider_encrypter( encrypter, cache = create_trigger_provider_encrypter_for_credential(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=db_provider, controller=provider_controller,
credential=db_provider,
) )
# Decrypt current credentials # Decrypt current credentials
@ -285,18 +319,8 @@ class TriggerProviderService:
:param provider_id: Provider identifier :param provider_id: Provider identifier
:return: OAuth client configuration or None :return: OAuth client configuration or None
""" """
# Get trigger provider controller to access schema provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
# Create encrypter for OAuth client params
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine, autoflush=False) as session: with Session(db.engine, autoflush=False) as session:
# First check for tenant-specific OAuth client
tenant_client: TriggerOAuthTenantClient | None = ( tenant_client: TriggerOAuthTenantClient | None = (
session.query(TriggerOAuthTenantClient) session.query(TriggerOAuthTenantClient)
.filter_by( .filter_by(
@ -310,10 +334,10 @@ class TriggerProviderService:
oauth_params: Mapping[str, Any] | None = None oauth_params: Mapping[str, Any] | None = None
if tenant_client: if tenant_client:
encrypter, _ = create_trigger_provider_oauth_encrypter(tenant_id, provider_controller)
oauth_params = encrypter.decrypt(tenant_client.oauth_params) oauth_params = encrypter.decrypt(tenant_client.oauth_params)
return oauth_params return oauth_params
# Only verified plugins can use system OAuth client
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id) is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
if not is_verified: if not is_verified:
return oauth_params return oauth_params
@ -354,7 +378,7 @@ class TriggerProviderService:
return {"result": "success"} return {"result": "success"}
# Get provider controller to access schema # Get provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id) provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
with Session(db.engine) as session: with Session(db.engine) as session:
# Find existing custom client params # Find existing custom client params
@ -425,7 +449,7 @@ class TriggerProviderService:
return {} return {}
# Get provider controller to access schema # Get provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id) provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
# Create encrypter to decrypt and mask values # Create encrypter to decrypt and mask values
encrypter, _ = create_provider_encrypter( encrypter, _ = create_provider_encrypter(
@ -477,63 +501,6 @@ class TriggerProviderService:
) )
return custom_client is not None return custom_client is not None
@classmethod
def create_oauth_proxy_context(
cls,
tenant_id: str,
user_id: str,
provider_id: TriggerProviderID,
) -> str:
"""
Create OAuth proxy context for authorization flow.
:param tenant_id: Tenant ID
:param user_id: User ID
:param provider: Provider identifier
:return: Context ID for OAuth flow
"""
return OAuthProxyService.create_proxy_context(
user_id=user_id,
tenant_id=tenant_id,
plugin_id=provider_id.plugin_id,
provider=provider_id.provider_name,
)
@classmethod
def _create_provider_encrypter(
cls, tenant_id: str, provider: TriggerProvider
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
"""
Create encrypter and cache for trigger provider credentials
:param tenant_id: Tenant ID
:param provider: TriggerProvider instance
:return: Tuple of encrypter and cache
"""
from core.helper.provider_cache import TriggerProviderCredentialCache
# Parse provider ID
provider_id = TriggerProviderID(provider.provider_id)
# Get trigger provider controller to access schema
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
# Create encrypter with appropriate schema based on credential type
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(provider.credential_type)
],
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider=provider.provider_id,
credential_id=provider.id,
),
)
return encrypter, cache
@classmethod @classmethod
def _generate_provider_name( def _generate_provider_name(
cls, cls,

View File

@ -1,5 +1,4 @@
import logging import logging
from typing import Any
from flask import Request, Response from flask import Request, Response
@ -21,3 +20,4 @@ class TriggerService:
# TODO dispatch by the trigger controller # TODO dispatch by the trigger controller
# TODO using the dispatch result(events) to invoke the trigger events # TODO using the dispatch result(events) to invoke the trigger events
raise NotImplementedError("Not implemented")