mirror of https://github.com/langgenius/dify.git
feat(trigger): implement complete OAuth authorization flow for trigger providers
- Add OAuth authorization URL generation API endpoint - Implement OAuth callback handler for credential storage - Support both system-level and tenant-level OAuth clients - Add trigger provider credential encryption utilities - Refactor trigger entities into separate modules - Update trigger provider service with OAuth client management - Add credential cache for trigger providers 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
87120ad4ac
commit
a46c9238fa
|
|
@ -1,20 +1,37 @@
|
|||
import logging
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""List all trigger providers for the current tenant"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
|
||||
|
||||
|
||||
class TriggerProviderCredentialListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -27,8 +44,10 @@ class TriggerProviderListApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
try:
|
||||
return TriggerProviderService.list_trigger_providers(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_credentials(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error listing trigger providers", exc_info=e)
|
||||
|
|
@ -145,30 +164,128 @@ class TriggerProviderOAuthAuthorizeApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Initiate OAuth authorization flow for a provider"""
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
context_id = TriggerProviderService.create_oauth_proxy_context(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# TODO: Build OAuth authorization URL
|
||||
# This will be implemented when we have provider-specific OAuth configs
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
return {
|
||||
"context_id": context_id,
|
||||
"authorization_url": f"/oauth/authorize?context={context_id}",
|
||||
}
|
||||
# Create OAuth handler and proxy context
|
||||
oauth_handler = OAuthHandler()
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
)
|
||||
|
||||
# Build redirect URI for callback
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
# Get authorization URL
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
)
|
||||
|
||||
# Create response with cookie
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error initiating OAuth flow", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerProviderOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
# Use and validate proxy context
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("No OAuth client configuration found for this trigger provider")
|
||||
|
||||
# Get OAuth credentials from callback
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
|
||||
|
||||
credentials_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
|
||||
credentials = credentials_response.credentials
|
||||
expires_at = credentials_response.expires_at
|
||||
|
||||
if not credentials:
|
||||
raise Exception("Failed to get OAuth credentials")
|
||||
|
||||
# Save OAuth credentials to database
|
||||
TriggerProviderService.add_trigger_provider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=CredentialType.OAUTH2,
|
||||
credentials=dict(credentials),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Redirect to OAuth callback page
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class TriggerProviderOAuthRefreshTokenApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
|
@ -257,16 +374,13 @@ class TriggerProviderOAuthClientManageApi(Resource):
|
|||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
result = TriggerProviderService.save_custom_oauth_client_params(
|
||||
return TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=args.get("client_params"),
|
||||
enabled=args.get("enabled"),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
|
|
@ -287,13 +401,10 @@ class TriggerProviderOAuthClientManageApi(Resource):
|
|||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
result = TriggerProviderService.delete_custom_oauth_client_params(
|
||||
return TriggerProviderService.delete_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
|
|
@ -302,8 +413,12 @@ class TriggerProviderOAuthClientManageApi(Resource):
|
|||
|
||||
|
||||
# Trigger provider endpoints
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-provider/<path:provider>/list")
|
||||
api.add_resource(TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider/<path:provider>/add")
|
||||
api.add_resource(
|
||||
TriggerProviderCredentialListApi, "/workspaces/current/trigger-provider/credentials/<path:provider>/list"
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider/credentials/<path:provider>/add"
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerProviderCredentialsUpdateApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/update"
|
||||
)
|
||||
|
|
@ -311,9 +426,11 @@ api.add_resource(
|
|||
TriggerProviderCredentialsDeleteApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/delete"
|
||||
)
|
||||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerProviderOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(
|
||||
TriggerProviderOAuthRefreshTokenApi,
|
||||
"/workspaces/current/trigger-provider/credentials/<path:credential_id>/oauth/refresh",
|
||||
|
|
|
|||
|
|
@ -71,15 +71,25 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
|||
class TriggerProviderCredentialCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
def __init__(self, tenant_id: str, provider_id: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id, credential_id=credential_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
credential_id = kwargs["credential_id"]
|
||||
return f"trigger_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
|
||||
|
||||
class TriggerProviderOAuthClientCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider OAuth client"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider_id=provider_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_id = kwargs["provider_id"]
|
||||
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
|
||||
|
||||
class NoOpProviderCredentialCache:
|
||||
"""No-op provider credential cache"""
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from core.plugin.entities.parameters import PluginParameterOption
|
|||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
from core.trigger.entities import TriggerProviderEntity
|
||||
from core.trigger.entities.entities import TriggerProviderEntity
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginToolProviderEntity, PluginTriggerProviderEntity
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginTriggerProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
|
|
@ -15,15 +15,15 @@ class PluginTriggerManager(BasePluginClient):
|
|||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for tool in declaration.get("tools", []):
|
||||
tool["identity"]["provider"] = provider_name
|
||||
for trigger in declaration.get("triggers", []):
|
||||
trigger["identity"]["provider"] = provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/tools",
|
||||
list[PluginToolProviderEntity],
|
||||
f"plugin/{tenant_id}/management/triggers",
|
||||
list[PluginTriggerProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
|
@ -32,37 +32,36 @@ class PluginTriggerManager(BasePluginClient):
|
|||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.tools:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
for trigger in provider.declaration.triggers:
|
||||
trigger.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
|
||||
def fetch_trigger_provider(self, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for tool in data.get("declaration", {}).get("tools", []):
|
||||
tool["identity"]["provider"] = tool_provider_id.provider_name
|
||||
for trigger in data.get("declaration", {}).get("triggers", []):
|
||||
trigger["identity"]["provider"] = provider_id.provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/tool",
|
||||
PluginToolProviderEntity,
|
||||
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||
f"plugin/{tenant_id}/management/trigger",
|
||||
PluginTriggerProviderEntity,
|
||||
params={"provider": provider_id.provider_name, "plugin_id": provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in response.declaration.tools:
|
||||
tool.identity.provider = response.declaration.identity.name
|
||||
# override the provider name for each trigger to plugin_id/provider_name
|
||||
for trigger in response.declaration.triggers:
|
||||
trigger.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -122,7 +122,6 @@ class ProviderConfigEncrypter:
|
|||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import (
|
||||
OAuthSchema,
|
||||
TriggerDescription,
|
||||
TriggerEntity,
|
||||
TriggerParameter,
|
||||
TriggerProviderIdentity,
|
||||
)
|
||||
|
||||
|
||||
class TriggerProviderCredentialApiEntity(BaseModel):
|
||||
id: str = Field(description="The unique id of the credential")
|
||||
name: str = Field(description="The name of the credential")
|
||||
provider: str = Field(description="The provider id of the credential")
|
||||
credential_type: CredentialType = Field(description="The type of the credential")
|
||||
credentials: dict = Field(description="The credentials of the credential")
|
||||
|
||||
|
||||
class TriggerProviderApiEntity(BaseModel):
|
||||
identity: TriggerProviderIdentity = Field(description="The identity of the trigger provider")
|
||||
credentials_schema: list[ProviderConfig] = Field(description="The credentials schema of the trigger provider")
|
||||
oauth_schema: Optional[OAuthSchema] = Field(description="The OAuth schema of the trigger provider")
|
||||
subscription_schema: list[ProviderConfig] = Field(description="The subscription schema of the trigger provider")
|
||||
triggers: list[TriggerEntity] = Field(description="The triggers of the trigger provider")
|
||||
|
||||
|
||||
class TriggerApiEntity(BaseModel):
|
||||
name: str = Field(description="The name of the trigger")
|
||||
description: TriggerDescription = Field(description="The description of the trigger")
|
||||
parameters: list[TriggerParameter] = Field(description="The parameters of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(description="The output schema of the trigger")
|
||||
|
||||
|
||||
__all__ = ["TriggerApiEntity", "TriggerProviderApiEntity", "TriggerProviderCredentialApiEntity"]
|
||||
|
|
@ -4,18 +4,11 @@ from typing import Any, Optional, Union
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import PluginParameterAutoGenerate, PluginParameterOption, PluginParameterTemplate
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class TriggerParameterOption(BaseModel):
|
||||
"""
|
||||
The option of the trigger parameter
|
||||
"""
|
||||
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
|
||||
class TriggerParameterType(StrEnum):
|
||||
"""The type of the parameter"""
|
||||
|
||||
|
|
@ -32,20 +25,6 @@ class TriggerParameterType(StrEnum):
|
|||
DYNAMIC_SELECT = "dynamic-select"
|
||||
|
||||
|
||||
class ParameterAutoGenerate(BaseModel):
|
||||
"""Auto generation configuration for parameters"""
|
||||
|
||||
enabled: bool = Field(default=False, description="Whether auto generation is enabled")
|
||||
template: Optional[str] = Field(default=None, description="Template for auto generation")
|
||||
|
||||
|
||||
class ParameterTemplate(BaseModel):
|
||||
"""Template configuration for parameters"""
|
||||
|
||||
value: str = Field(..., description="Template value")
|
||||
type: str = Field(default="jinja2", description="Template type")
|
||||
|
||||
|
||||
class TriggerParameter(BaseModel):
|
||||
"""
|
||||
The parameter of the trigger
|
||||
|
|
@ -54,17 +33,17 @@ class TriggerParameter(BaseModel):
|
|||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
type: TriggerParameterType = Field(..., description="The type of the parameter")
|
||||
auto_generate: Optional[ParameterAutoGenerate] = Field(
|
||||
auto_generate: Optional[PluginParameterAutoGenerate] = Field(
|
||||
default=None, description="The auto generate of the parameter"
|
||||
)
|
||||
template: Optional[ParameterTemplate] = Field(default=None, description="The template of the parameter")
|
||||
template: Optional[PluginParameterTemplate] = Field(default=None, description="The template of the parameter")
|
||||
scope: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
default: Union[int, float, str, None] = None
|
||||
min: Union[float, int, None] = None
|
||||
max: Union[float, int, None] = None
|
||||
precision: Optional[int] = None
|
||||
options: Optional[list[TriggerParameterOption]] = None
|
||||
options: Optional[list[PluginParameterOption]] = None
|
||||
description: Optional[I18nObject] = None
|
||||
|
||||
|
||||
|
|
@ -89,7 +68,7 @@ class TriggerIdentity(BaseModel):
|
|||
author: str = Field(..., description="The author of the trigger")
|
||||
name: str = Field(..., description="The name of the trigger")
|
||||
label: I18nObject = Field(..., description="The label of the trigger")
|
||||
|
||||
provider: str = Field(..., description="The provider of the trigger")
|
||||
|
||||
class TriggerDescription(BaseModel):
|
||||
"""
|
||||
|
|
@ -100,69 +79,23 @@ class TriggerDescription(BaseModel):
|
|||
llm: I18nObject = Field(..., description="LLM readable description")
|
||||
|
||||
|
||||
class TriggerConfigurationExtraPython(BaseModel):
|
||||
"""Python configuration for trigger"""
|
||||
|
||||
source: str = Field(..., description="The source file path for the trigger implementation")
|
||||
|
||||
|
||||
class TriggerConfigurationExtra(BaseModel):
|
||||
"""
|
||||
The extra configuration for trigger
|
||||
"""
|
||||
|
||||
|
||||
class TriggerEntity(BaseModel):
|
||||
"""
|
||||
The configuration of a trigger
|
||||
"""
|
||||
|
||||
python: TriggerConfigurationExtraPython
|
||||
identity: TriggerIdentity = Field(..., description="The identity of the trigger")
|
||||
parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger")
|
||||
description: TriggerDescription = Field(..., description="The description of the trigger")
|
||||
extra: TriggerConfigurationExtra = Field(..., description="The extra configuration of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(
|
||||
default=None, description="The output schema that this trigger produces"
|
||||
)
|
||||
|
||||
|
||||
class TriggerProviderConfigurationExtraPython(BaseModel):
|
||||
"""Python configuration for trigger provider"""
|
||||
|
||||
source: str = Field(..., description="The source file path for the trigger provider implementation")
|
||||
|
||||
|
||||
class TriggerProviderConfigurationExtra(BaseModel):
|
||||
"""
|
||||
The extra configuration for trigger provider
|
||||
"""
|
||||
|
||||
python: TriggerProviderConfigurationExtraPython
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
"""OAuth configuration schema"""
|
||||
|
||||
authorization_url: str = Field(..., description="OAuth authorization URL")
|
||||
token_url: str = Field(..., description="OAuth token URL")
|
||||
client_id: str = Field(..., description="OAuth client ID")
|
||||
client_secret: str = Field(..., description="OAuth client secret")
|
||||
redirect_uri: Optional[str] = Field(default=None, description="OAuth redirect URI")
|
||||
scope: Optional[str] = Field(default=None, description="OAuth scope")
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
|
||||
"""Provider configuration item"""
|
||||
|
||||
name: str = Field(..., description="Configuration field name")
|
||||
type: str = Field(..., description="Configuration field type")
|
||||
required: bool = Field(default=False, description="Whether this field is required")
|
||||
default: Any = Field(default=None, description="Default value")
|
||||
label: Optional[I18nObject] = Field(default=None, description="Field label")
|
||||
description: Optional[I18nObject] = Field(default=None, description="Field description")
|
||||
options: Optional[list[dict[str, Any]]] = Field(default=None, description="Options for select type")
|
||||
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
class TriggerProviderEntity(BaseModel):
|
||||
"""
|
||||
|
|
@ -183,7 +116,6 @@ class TriggerProviderEntity(BaseModel):
|
|||
description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters",
|
||||
)
|
||||
triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")
|
||||
extra: TriggerProviderConfigurationExtra = Field(..., description="The extra configuration of the trigger provider")
|
||||
|
||||
|
||||
class Subscription(BaseModel):
|
||||
|
|
@ -223,21 +155,12 @@ class Unsubscription(BaseModel):
|
|||
# Export all entities
|
||||
__all__ = [
|
||||
"OAuthSchema",
|
||||
"ParameterAutoGenerate",
|
||||
"ParameterTemplate",
|
||||
"ProviderConfig",
|
||||
"Subscription",
|
||||
"TriggerConfigurationExtra",
|
||||
"TriggerConfigurationExtraPython",
|
||||
"TriggerDescription",
|
||||
"TriggerEntity",
|
||||
"TriggerEntity",
|
||||
"TriggerIdentity",
|
||||
"TriggerParameter",
|
||||
"TriggerParameterOption",
|
||||
"TriggerParameterType",
|
||||
"TriggerProviderConfigurationExtra",
|
||||
"TriggerProviderConfigurationExtraPython",
|
||||
"TriggerProviderEntity",
|
||||
"TriggerProviderIdentity",
|
||||
"Unsubscription",
|
||||
|
|
@ -6,8 +6,11 @@ import logging
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities import (
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity
|
||||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
Subscription,
|
||||
TriggerEntity,
|
||||
|
|
@ -19,7 +22,7 @@ from core.trigger.entities import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderController:
|
||||
class PluginTriggerProviderController:
|
||||
"""
|
||||
Controller for plugin trigger providers
|
||||
"""
|
||||
|
|
@ -44,6 +47,18 @@ class TriggerProviderController:
|
|||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_provider_id(self) -> TriggerProviderID:
|
||||
"""
|
||||
Get provider ID
|
||||
"""
|
||||
return TriggerProviderID(f"{self.plugin_id}/{self.entity.identity.name}")
|
||||
|
||||
def to_api_entity(self) -> TriggerProviderApiEntity:
|
||||
"""
|
||||
Convert to API entity
|
||||
"""
|
||||
return TriggerProviderApiEntity(**self.entity.model_dump())
|
||||
|
||||
@property
|
||||
def identity(self) -> TriggerProviderIdentity:
|
||||
"""Get provider identity"""
|
||||
|
|
@ -69,14 +84,6 @@ class TriggerProviderController:
|
|||
return trigger
|
||||
return None
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get credentials schema for this provider
|
||||
|
||||
:return: List of provider config schemas
|
||||
"""
|
||||
return self.entity.credentials_schema
|
||||
|
||||
def get_subscription_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get subscription schema for this provider
|
||||
|
|
@ -109,18 +116,24 @@ class TriggerProviderController:
|
|||
types.append(CredentialType.API_KEY)
|
||||
return types
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
def get_credentials_schema(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get credentials schema by credential type
|
||||
|
||||
:param credential_type: The type of credential (oauth or api_key)
|
||||
:return: List of provider config schemas
|
||||
"""
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
credential_type = CredentialType.of(credential_type) if isinstance(credential_type, str) else credential_type
|
||||
if credential_type == CredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY.value:
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_credential_schema_config(self, credential_type: CredentialType | str) -> list[BasicProviderConfig]:
|
||||
"""
|
||||
Get credential schema config by credential type
|
||||
"""
|
||||
return [x.to_basic_provider_config() for x in self.get_credentials_schema(credential_type)]
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
|
|
@ -183,17 +196,5 @@ class TriggerProviderController:
|
|||
logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id)
|
||||
return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}")
|
||||
|
||||
def handle_webhook(self, webhook_path: str, request_data: dict, credentials: dict) -> dict:
|
||||
"""
|
||||
Handle incoming webhook through plugin runtime
|
||||
|
||||
:param webhook_path: Webhook path
|
||||
:param request_data: Request data
|
||||
:param credentials: Provider credentials
|
||||
:return: Webhook handling result
|
||||
"""
|
||||
logger.info("Handling webhook for path %s for plugin %s", webhook_path, self.plugin_id)
|
||||
return {"success": True, "path": webhook_path, "plugin": self.plugin_id, "data_received": request_data}
|
||||
|
||||
|
||||
__all__ = ["TriggerProviderController"]
|
||||
__all__ = ["PluginTriggerProviderController"]
|
||||
|
|
|
|||
|
|
@ -5,11 +5,14 @@ Trigger Manager for loading and managing trigger providers and triggers
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.trigger.entities import (
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.impl.trigger import PluginTriggerManager
|
||||
from core.trigger.entities.entities import (
|
||||
ProviderConfig,
|
||||
Subscription,
|
||||
TriggerEntity,
|
||||
Unsubscription,
|
||||
)
|
||||
from core.trigger.plugin_trigger import PluginTriggerController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -28,7 +31,7 @@ class TriggerManager:
|
|||
:param tenant_id: Tenant ID
|
||||
:return: List of trigger provider controllers
|
||||
"""
|
||||
manager = PluginTriggerController()
|
||||
manager = PluginTriggerManager()
|
||||
provider_entities = manager.fetch_trigger_providers(tenant_id)
|
||||
|
||||
controllers = []
|
||||
|
|
@ -48,22 +51,21 @@ class TriggerManager:
|
|||
return controllers
|
||||
|
||||
@classmethod
|
||||
def get_plugin_trigger_provider(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str
|
||||
) -> Optional[PluginTriggerProviderController]:
|
||||
def get_trigger_provider(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID
|
||||
) -> PluginTriggerProviderController:
|
||||
"""
|
||||
Get a specific plugin trigger provider
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param provider_id: Provider ID
|
||||
:return: Trigger provider controller or None
|
||||
"""
|
||||
manager = PluginTriggerManager()
|
||||
provider = manager.fetch_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
provider = manager.fetch_trigger_provider(tenant_id, provider_id)
|
||||
|
||||
if not provider:
|
||||
return None
|
||||
raise ValueError(f"Trigger provider {provider_id} not found")
|
||||
|
||||
try:
|
||||
return PluginTriggerProviderController(
|
||||
|
|
@ -74,287 +76,139 @@ class TriggerManager:
|
|||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to load trigger provider")
|
||||
return None
|
||||
raise e
|
||||
|
||||
@classmethod
|
||||
def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
|
||||
"""
|
||||
List all trigger providers (plugin and builtin)
|
||||
List all trigger providers (plugin)
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:return: List of all trigger provider controllers
|
||||
"""
|
||||
providers = []
|
||||
|
||||
# Get plugin providers
|
||||
plugin_providers = cls.list_plugin_trigger_providers(tenant_id)
|
||||
providers.extend(plugin_providers)
|
||||
|
||||
# TODO: Add builtin providers when implemented
|
||||
# builtin_providers = cls.list_builtin_trigger_providers(tenant_id)
|
||||
# providers.extend(builtin_providers)
|
||||
|
||||
return providers
|
||||
return cls.list_plugin_trigger_providers(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def list_triggers_by_provider(cls, tenant_id: str, plugin_id: str, provider_name: str) -> list[TriggerEntity]:
|
||||
def list_triggers_by_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerEntity]:
|
||||
"""
|
||||
List all triggers for a specific provider
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param provider_id: Provider ID
|
||||
:return: List of trigger entities
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return []
|
||||
|
||||
provider = cls.get_trigger_provider(tenant_id, provider_id)
|
||||
return provider.get_triggers()
|
||||
|
||||
@classmethod
|
||||
def get_trigger(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str
|
||||
) -> Optional[TriggerEntity]:
|
||||
"""
|
||||
Get a specific trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:return: Trigger entity or None
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
return provider.get_trigger(trigger_name)
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
|
||||
|
||||
@classmethod
|
||||
def validate_trigger_credentials(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str, credentials: dict
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID, credentials: dict
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate trigger provider credentials
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param provider_id: Provider ID
|
||||
:param credentials: Credentials to validate
|
||||
:return: Tuple of (is_valid, error_message)
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return False, "Provider not found"
|
||||
|
||||
try:
|
||||
provider.validate_credentials(credentials)
|
||||
cls.get_trigger_provider(tenant_id, provider_id).validate_credentials(credentials)
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def execute_trigger(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str, parameters: dict, credentials: dict
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str, parameters: dict, credentials: dict
|
||||
) -> dict:
|
||||
"""
|
||||
Execute a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:param parameters: Trigger parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Trigger execution result
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
|
||||
|
||||
trigger = provider.get_trigger(trigger_name)
|
||||
trigger = cls.get_trigger_provider(tenant_id, provider_id).get_trigger(trigger_name)
|
||||
if not trigger:
|
||||
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_name}")
|
||||
|
||||
return provider.execute_trigger(trigger_name, parameters, credentials)
|
||||
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_id}")
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).execute_trigger(trigger_name, parameters, credentials)
|
||||
|
||||
@classmethod
|
||||
def subscribe_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
plugin_id: str,
|
||||
provider_name: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
subscription_params: dict,
|
||||
credentials: dict,
|
||||
) -> dict:
|
||||
) -> Subscription:
|
||||
"""
|
||||
Subscribe to a trigger (e.g., register webhook)
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_params: Subscription parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Subscription result
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
|
||||
|
||||
return provider.subscribe_trigger(trigger_name, subscription_params, credentials)
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).subscribe_trigger(
|
||||
trigger_name, subscription_params, credentials
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unsubscribe_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
plugin_id: str,
|
||||
provider_name: str,
|
||||
provider_id: TriggerProviderID,
|
||||
trigger_name: str,
|
||||
subscription_metadata: dict,
|
||||
credentials: dict,
|
||||
) -> dict:
|
||||
) -> Unsubscription:
|
||||
"""
|
||||
Unsubscribe from a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param provider_id: Provider ID
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_metadata: Subscription metadata from subscribe operation
|
||||
:param credentials: Provider credentials
|
||||
:return: Unsubscription result
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
|
||||
|
||||
return provider.unsubscribe_trigger(trigger_name, subscription_metadata, credentials)
|
||||
|
||||
@classmethod
|
||||
def handle_webhook(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
plugin_id: str,
|
||||
provider_name: str,
|
||||
webhook_path: str,
|
||||
request_data: dict,
|
||||
credentials: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Handle incoming webhook for a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param webhook_path: Webhook path
|
||||
:param request_data: Webhook request data
|
||||
:param credentials: Provider credentials
|
||||
:return: Webhook handling result
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
|
||||
|
||||
return provider.handle_webhook(webhook_path, request_data, credentials)
|
||||
|
||||
@classmethod
|
||||
def get_provider_credentials_schema(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str
|
||||
) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get provider credentials schema
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:return: List of provider config schemas
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return []
|
||||
|
||||
return provider.get_credentials_schema()
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).unsubscribe_trigger(
|
||||
trigger_name, subscription_metadata, credentials
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_provider_subscription_schema(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID
|
||||
) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get provider subscription schema
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param provider_id: Provider ID
|
||||
:return: List of subscription config schemas
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return []
|
||||
|
||||
return provider.get_subscription_schema()
|
||||
|
||||
@classmethod
|
||||
def get_provider_info(cls, tenant_id: str, plugin_id: str, provider_name: str) -> Optional[dict]:
|
||||
"""
|
||||
Get provider information
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:return: Provider info dict or None
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"provider_name": provider_name,
|
||||
"identity": provider.entity.identity.model_dump() if provider.entity.identity else {},
|
||||
"credentials_schema": [c.model_dump() for c in provider.entity.credentials_schema],
|
||||
"subscription_schema": [s.model_dump() for s in provider.entity.subscription_schema],
|
||||
"oauth_enabled": provider.entity.oauth_schema is not None,
|
||||
"trigger_count": len(provider.entity.triggers),
|
||||
"triggers": [t.identity.model_dump() for t in provider.entity.triggers],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_providers_for_workflow(cls, tenant_id: str) -> list[dict]:
|
||||
"""
|
||||
List trigger providers suitable for workflow usage
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:return: List of provider info dicts
|
||||
"""
|
||||
providers = cls.list_all_trigger_providers(tenant_id)
|
||||
|
||||
result = []
|
||||
for provider in providers:
|
||||
info = {
|
||||
"plugin_id": provider.plugin_id,
|
||||
"provider_name": provider.entity.identity.name,
|
||||
"label": provider.entity.identity.label.model_dump(),
|
||||
"description": provider.entity.identity.description.model_dump(),
|
||||
"icon": provider.entity.identity.icon,
|
||||
"trigger_count": len(provider.entity.triggers),
|
||||
}
|
||||
result.append(info)
|
||||
|
||||
return result
|
||||
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
|
||||
|
||||
# Export
|
||||
__all__ = ["TriggerManager"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
from core.helper.provider_cache import TriggerProviderCredentialCache, TriggerProviderOAuthClientCache
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from models.trigger import TriggerProvider
|
||||
|
||||
|
||||
def create_trigger_provider_encrypter_for_credential(
|
||||
tenant_id: str,
|
||||
controller: PluginTriggerProviderController,
|
||||
credential: TriggerProvider | TriggerProviderCredentialApiEntity,
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credential_schema_config(credential.credential_type),
|
||||
cache=TriggerProviderCredentialCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
credential_id=credential.id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_trigger_provider_encrypter(
|
||||
tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credential_schema_config(credential_type),
|
||||
cache=TriggerProviderCredentialCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_trigger_provider_oauth_encrypter(
|
||||
tenant_id: str, controller: PluginTriggerProviderController
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
|
||||
cache=TriggerProviderOAuthClientCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
),
|
||||
)
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
from datetime import UTC, datetime
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
|
@ -7,6 +8,7 @@ from sqlalchemy import DateTime, Index, Integer, String, Text, func
|
|||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity
|
||||
from models.base import Base
|
||||
from models.types import StringUUID
|
||||
|
||||
|
|
@ -45,20 +47,23 @@ class TriggerProvider(Base):
|
|||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def credentials_str(self) -> str:
|
||||
"""Get credentials as string"""
|
||||
return self.encrypted_credentials or "{}"
|
||||
|
||||
def is_oauth_expired(self) -> bool:
|
||||
"""Check if OAuth token is expired"""
|
||||
if self.credential_type != CredentialType.OAUTH2.value:
|
||||
return False
|
||||
if self.expires_at == -1:
|
||||
return False
|
||||
# Check if token expires in next 60 seconds
|
||||
return (self.expires_at - 60) < int(datetime.now(UTC).timestamp())
|
||||
# Check if token expires in next 3 minutes
|
||||
return (self.expires_at - 180) < int(time.time())
|
||||
|
||||
def to_api_entity(self) -> TriggerProviderCredentialApiEntity:
|
||||
return TriggerProviderCredentialApiEntity(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
provider=self.provider_id,
|
||||
credential_type=CredentialType(self.credential_type),
|
||||
credentials=self.credentials,
|
||||
)
|
||||
|
||||
# system level trigger oauth client params
|
||||
class TriggerOAuthSystemClient(Base):
|
||||
|
|
|
|||
|
|
@ -13,14 +13,20 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
|||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.service import PluginService
|
||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.tools.utils.encryption import (
|
||||
create_provider_encrypter,
|
||||
)
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import (
|
||||
create_trigger_provider_encrypter_for_credential,
|
||||
create_trigger_provider_oauth_encrypter,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,13 +37,34 @@ class TriggerProviderService:
|
|||
__MAX_TRIGGER_PROVIDER_COUNT__ = 100
|
||||
|
||||
@classmethod
|
||||
def list_trigger_providers(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerProvider]:
|
||||
def list_trigger_providers(cls, tenant_id: str) -> list[TriggerProviderApiEntity]:
|
||||
"""List all trigger providers for the current tenant"""
|
||||
# TODO fetch trigger plugin controller
|
||||
return [provider.to_api_entity() for provider in TriggerManager.list_all_trigger_providers(tenant_id)]
|
||||
|
||||
# TODO fetch all trigger plugin credentials
|
||||
@classmethod
|
||||
def list_trigger_provider_credentials(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID
|
||||
) -> list[TriggerProviderCredentialApiEntity]:
|
||||
"""List all trigger providers for the current tenant"""
|
||||
credentials: list[TriggerProviderCredentialApiEntity] = []
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
return session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).all()
|
||||
credentials_db = (
|
||||
session.query(TriggerProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=str(provider_id))
|
||||
.order_by(desc(TriggerProvider.created_at))
|
||||
.all()
|
||||
)
|
||||
credentials = [credential.to_api_entity() for credential in credentials_db]
|
||||
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
for credential in credentials:
|
||||
encrypter, _ = create_trigger_provider_encrypter_for_credential(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
credential=credential,
|
||||
)
|
||||
credential.credentials = encrypter.decrypt(credential.credentials)
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def add_trigger_provider(
|
||||
|
|
@ -63,6 +90,7 @@ class TriggerProviderService:
|
|||
:return: Success response
|
||||
"""
|
||||
try:
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
with Session(db.engine) as session:
|
||||
# Use distributed lock to prevent race conditions
|
||||
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
|
||||
|
|
@ -96,10 +124,9 @@ class TriggerProviderService:
|
|||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||
|
||||
# Create encrypter for credentials
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[], # We'll define schema later in TriggerProvider classes
|
||||
config=provider_controller.get_credential_schema_config(credential_type),
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
|
|
@ -141,20 +168,21 @@ class TriggerProviderService:
|
|||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Get provider
|
||||
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider credential {credential_id} not found")
|
||||
|
||||
try:
|
||||
# Update credentials if provided
|
||||
if credentials:
|
||||
encrypter, cache = cls._create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
provider=db_provider,
|
||||
)
|
||||
provider_controller = TriggerManager.get_trigger_provider(
|
||||
tenant_id, TriggerProviderID(db_provider.provider_id)
|
||||
)
|
||||
|
||||
if credentials:
|
||||
encrypter, cache = create_trigger_provider_encrypter_for_credential(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
credential=db_provider,
|
||||
)
|
||||
# Handle hidden values
|
||||
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
new_credentials = {
|
||||
|
|
@ -200,14 +228,20 @@ class TriggerProviderService:
|
|||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider credential {credential_id} not found")
|
||||
|
||||
# Delete provider
|
||||
provider_controller = TriggerManager.get_trigger_provider(
|
||||
tenant_id, TriggerProviderID(db_provider.provider_id)
|
||||
)
|
||||
# Clear cache
|
||||
_, cache = create_trigger_provider_encrypter_for_credential(
|
||||
tenant_id=tenant_id,
|
||||
controller=provider_controller,
|
||||
credential=db_provider,
|
||||
)
|
||||
|
||||
session.delete(db_provider)
|
||||
session.commit()
|
||||
|
||||
# Clear cache
|
||||
_, cache = cls._create_provider_encrypter(tenant_id, db_provider)
|
||||
cache.delete()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -232,13 +266,13 @@ class TriggerProviderService:
|
|||
if db_provider.credential_type != CredentialType.OAUTH2.value:
|
||||
raise ValueError("Only OAuth credentials can be refreshed")
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(db_provider.provider_id)
|
||||
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
# Create encrypter
|
||||
encrypter, cache = cls._create_provider_encrypter(
|
||||
encrypter, cache = create_trigger_provider_encrypter_for_credential(
|
||||
tenant_id=tenant_id,
|
||||
provider=db_provider,
|
||||
controller=provider_controller,
|
||||
credential=db_provider,
|
||||
)
|
||||
|
||||
# Decrypt current credentials
|
||||
|
|
@ -285,18 +319,8 @@ class TriggerProviderService:
|
|||
:param provider_id: Provider identifier
|
||||
:return: OAuth client configuration or None
|
||||
"""
|
||||
# Get trigger provider controller to access schema
|
||||
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
|
||||
|
||||
# Create encrypter for OAuth client params
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
# First check for tenant-specific OAuth client
|
||||
tenant_client: TriggerOAuthTenantClient | None = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
|
|
@ -310,10 +334,10 @@ class TriggerProviderService:
|
|||
|
||||
oauth_params: Mapping[str, Any] | None = None
|
||||
if tenant_client:
|
||||
encrypter, _ = create_trigger_provider_oauth_encrypter(tenant_id, provider_controller)
|
||||
oauth_params = encrypter.decrypt(tenant_client.oauth_params)
|
||||
return oauth_params
|
||||
|
||||
# Only verified plugins can use system OAuth client
|
||||
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
|
||||
if not is_verified:
|
||||
return oauth_params
|
||||
|
|
@ -354,7 +378,7 @@ class TriggerProviderService:
|
|||
return {"result": "success"}
|
||||
|
||||
# Get provider controller to access schema
|
||||
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find existing custom client params
|
||||
|
|
@ -425,7 +449,7 @@ class TriggerProviderService:
|
|||
return {}
|
||||
|
||||
# Get provider controller to access schema
|
||||
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
|
||||
provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id)
|
||||
|
||||
# Create encrypter to decrypt and mask values
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
|
|
@ -477,63 +501,6 @@ class TriggerProviderService:
|
|||
)
|
||||
return custom_client is not None
|
||||
|
||||
@classmethod
|
||||
def create_oauth_proxy_context(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
) -> str:
|
||||
"""
|
||||
Create OAuth proxy context for authorization flow.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider: Provider identifier
|
||||
:return: Context ID for OAuth flow
|
||||
"""
|
||||
return OAuthProxyService.create_proxy_context(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _create_provider_encrypter(
|
||||
cls, tenant_id: str, provider: TriggerProvider
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
"""
|
||||
Create encrypter and cache for trigger provider credentials
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider: TriggerProvider instance
|
||||
:return: Tuple of encrypter and cache
|
||||
"""
|
||||
from core.helper.provider_cache import TriggerProviderCredentialCache
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(provider.provider_id)
|
||||
|
||||
# Get trigger provider controller to access schema
|
||||
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
|
||||
|
||||
# Create encrypter with appropriate schema based on credential type
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(provider.credential_type)
|
||||
],
|
||||
cache=TriggerProviderCredentialCache(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider.provider_id,
|
||||
credential_id=provider.id,
|
||||
),
|
||||
)
|
||||
|
||||
return encrypter, cache
|
||||
|
||||
@classmethod
|
||||
def _generate_provider_name(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
|
|
@ -21,3 +20,4 @@ class TriggerService:
|
|||
# TODO dispatch by the trigger controller
|
||||
|
||||
# TODO using the dispatch result(events) to invoke the trigger events
|
||||
raise NotImplementedError("Not implemented")
|
||||
Loading…
Reference in New Issue