mirror of https://github.com/langgenius/dify.git
feat(trigger): add trigger provider management and webhook handling functionality
This commit is contained in:
parent
7544b5ec9a
commit
87120ad4ac
|
|
@ -59,6 +59,7 @@ pnpm test # Run Jest tests
|
|||
- Use type hints for all functions and class attributes
|
||||
- No `Any` types unless absolutely necessary
|
||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
||||
- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead
|
||||
|
||||
### TypeScript/JavaScript
|
||||
|
||||
|
|
|
|||
|
|
@ -1207,6 +1207,55 @@ def setup_system_tool_oauth_client(provider, client_params):
|
|||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ from core.mcp.error import MCPAuthError, MCPError
|
|||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from libs.helper import StrLen, alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
|
|
|||
|
|
@ -0,0 +1,323 @@
|
|||
import logging
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""List all trigger providers for the current tenant"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
return TriggerProviderService.list_trigger_providers(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error listing trigger providers", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerProviderCredentialsAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Add a new credential instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("expires_at", type=int, required=False, nullable=True, location="json", default=-1)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Parse credential type
|
||||
try:
|
||||
credential_type = CredentialType(args["credential_type"])
|
||||
except ValueError:
|
||||
raise BadRequest(f"Invalid credential_type. Must be one of: {[t.value for t in CredentialType]}")
|
||||
|
||||
result = TriggerProviderService.add_trigger_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
credential_type=credential_type,
|
||||
credentials=args["credentials"],
|
||||
name=args.get("name"),
|
||||
expires_at=args.get("expires_at", -1),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error adding provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerProviderCredentialsUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, credential_id):
|
||||
"""Update an existing credential instance"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
result = TriggerProviderService.update_trigger_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
credentials=args.get("credentials"),
|
||||
name=args.get("name"),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error updating provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerProviderCredentialsDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, credential_id):
|
||||
"""Delete a credential instance"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
result = TriggerProviderService.delete_trigger_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting provider credential", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerProviderOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Initiate OAuth authorization flow for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
try:
|
||||
context_id = TriggerProviderService.create_oauth_proxy_context(
|
||||
tenant_id=user.current_tenant_id,
|
||||
user_id=user.id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
)
|
||||
|
||||
# TODO: Build OAuth authorization URL
|
||||
# This will be implemented when we have provider-specific OAuth configs
|
||||
|
||||
return {
|
||||
"context_id": context_id,
|
||||
"authorization_url": f"/oauth/authorize?context={context_id}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error initiating OAuth flow", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerProviderOAuthRefreshTokenApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, credential_id):
|
||||
"""Refresh OAuth token for a trigger provider credential"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
result = TriggerProviderService.refresh_oauth_token(
|
||||
tenant_id=user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error refreshing OAuth token", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
class TriggerProviderOAuthClientManageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
# Get custom OAuth client params if exists
|
||||
custom_params = TriggerProviderService.get_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if custom client is enabled
|
||||
is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
# Check if there's a system OAuth client
|
||||
system_client = TriggerProviderService.get_oauth_client(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"configured": bool(custom_params or system_client),
|
||||
"custom_configured": bool(custom_params),
|
||||
"custom_enabled": is_custom_enabled,
|
||||
"params": custom_params if custom_params else {},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
result = TriggerProviderService.save_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
client_params=args.get("client_params"),
|
||||
enabled=args.get("enabled"),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error configuring OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
provider_id = TriggerProviderID(provider)
|
||||
|
||||
result = TriggerProviderService.delete_custom_oauth_client_params(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error removing OAuth client", exc_info=e)
|
||||
raise
|
||||
|
||||
|
||||
# Trigger provider endpoints
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-provider/<path:provider>/list")
|
||||
api.add_resource(TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider/<path:provider>/add")
|
||||
api.add_resource(
|
||||
TriggerProviderCredentialsUpdateApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/update"
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerProviderCredentialsDeleteApi, "/workspaces/current/trigger-provider/credentials/<path:credential_id>/delete"
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/authorize"
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerProviderOAuthRefreshTokenApi,
|
||||
"/workspaces/current/trigger-provider/credentials/<path:credential_id>/oauth/refresh",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerProviderOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client"
|
||||
)
|
||||
|
|
@ -4,4 +4,4 @@ from flask import Blueprint
|
|||
bp = Blueprint("trigger", __name__, url_prefix="/triggers")
|
||||
|
||||
# Import routes after blueprint creation to avoid circular imports
|
||||
from . import webhook
|
||||
from . import trigger, webhook
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
import logging
|
||||
|
||||
from flask import jsonify, request
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.trigger import bp
|
||||
from services.trigger_service import TriggerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@bp.route("/trigger/webhook/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def trigger_webhook(endpoint_id: str):
|
||||
"""
|
||||
Handle webhook trigger calls.
|
||||
"""
|
||||
try:
|
||||
return TriggerService.process_webhook(endpoint_id, request)
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Webhook processing failed for {endpoint_id}")
|
||||
return jsonify({"error": "Internal server error", "message": str(e)}), 500
|
||||
|
|
@ -68,6 +68,19 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
|||
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class TriggerProviderCredentialCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
credential_id = kwargs["credential_id"]
|
||||
return f"trigger_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class NoOpProviderCredentialCache:
|
||||
"""No-op provider credential cache"""
|
||||
|
||||
|
|
|
|||
|
|
@ -184,6 +184,10 @@ class ToolProviderID(GenericProviderID):
|
|||
self.plugin_name = f"{self.provider_name}_tool"
|
||||
|
||||
|
||||
class TriggerProviderID(GenericProviderID):
|
||||
pass
|
||||
|
||||
|
||||
class PluginDependency(BaseModel):
|
||||
class Type(enum.StrEnum):
|
||||
Github = PluginInstallationSource.Github.value
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import enum
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
|
@ -13,6 +14,7 @@ from core.plugin.entities.parameters import PluginParameterOption
|
|||
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
|
||||
from core.trigger.entities import TriggerProviderEntity
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
|
|
@ -196,3 +198,43 @@ class PluginListResponse(BaseModel):
|
|||
|
||||
class PluginDynamicSelectOptionsResponse(BaseModel):
|
||||
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
|
||||
|
||||
|
||||
class PluginTriggerProviderEntity(BaseModel):
|
||||
provider: str
|
||||
plugin_unique_identifier: str
|
||||
plugin_id: str
|
||||
declaration: TriggerProviderEntity
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
def is_validate_allowed(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api-key":
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
return cls.OAUTH2
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@ from typing import Any, Optional
|
|||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from core.plugin.utils.chunk_merger import merge_blob_chunks
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginClient):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
from typing import Any
|
||||
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginToolProviderEntity, PluginTriggerProviderEntity
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
|
||||
|
||||
class PluginTriggerManager(BasePluginClient):
|
||||
def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]:
|
||||
"""
|
||||
Fetch tool providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for tool in declaration.get("tools", []):
|
||||
tool["identity"]["provider"] = provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/tools",
|
||||
list[PluginToolProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.tools:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for tool in data.get("declaration", {}).get("tools", []):
|
||||
tool["identity"]["provider"] = tool_provider_id.provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/tool",
|
||||
PluginToolProviderEntity,
|
||||
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in response.declaration.tools:
|
||||
tool.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
|
@ -4,7 +4,8 @@ from openai import BaseModel
|
|||
from pydantic import Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||
|
||||
|
||||
class ToolRuntime(BaseModel):
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ from typing import Any
|
|||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import (
|
||||
CredentialType,
|
||||
OAuthSchema,
|
||||
ToolEntity,
|
||||
ToolProviderEntity,
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from typing import Any, Literal, Optional
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class ToolApiEntity(BaseModel):
|
||||
|
|
|
|||
|
|
@ -476,36 +476,3 @@ class ToolSelector(BaseModel):
|
|||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
return "API KEY"
|
||||
elif self == CredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
else:
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
def is_validate_allowed(self):
|
||||
return self == CredentialType.API_KEY
|
||||
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [item.value for item in cls]
|
||||
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api-key":
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
return cls.OAUTH2
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
|
|
@ -47,7 +48,6 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
|
|||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
CredentialType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
# Core trigger module initialization
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class TriggerParameterOption(BaseModel):
|
||||
"""
|
||||
The option of the trigger parameter
|
||||
"""
|
||||
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
|
||||
class TriggerParameterType(StrEnum):
|
||||
"""The type of the parameter"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
FILE = "file"
|
||||
FILES = "files"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
APP_SELECTOR = "app-selector"
|
||||
OBJECT = "object"
|
||||
ARRAY = "array"
|
||||
DYNAMIC_SELECT = "dynamic-select"
|
||||
|
||||
|
||||
class ParameterAutoGenerate(BaseModel):
|
||||
"""Auto generation configuration for parameters"""
|
||||
|
||||
enabled: bool = Field(default=False, description="Whether auto generation is enabled")
|
||||
template: Optional[str] = Field(default=None, description="Template for auto generation")
|
||||
|
||||
|
||||
class ParameterTemplate(BaseModel):
|
||||
"""Template configuration for parameters"""
|
||||
|
||||
value: str = Field(..., description="Template value")
|
||||
type: str = Field(default="jinja2", description="Template type")
|
||||
|
||||
|
||||
class TriggerParameter(BaseModel):
|
||||
"""
|
||||
The parameter of the trigger
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
type: TriggerParameterType = Field(..., description="The type of the parameter")
|
||||
auto_generate: Optional[ParameterAutoGenerate] = Field(
|
||||
default=None, description="The auto generate of the parameter"
|
||||
)
|
||||
template: Optional[ParameterTemplate] = Field(default=None, description="The template of the parameter")
|
||||
scope: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
default: Union[int, float, str, None] = None
|
||||
min: Union[float, int, None] = None
|
||||
max: Union[float, int, None] = None
|
||||
precision: Optional[int] = None
|
||||
options: Optional[list[TriggerParameterOption]] = None
|
||||
description: Optional[I18nObject] = None
|
||||
|
||||
|
||||
class TriggerProviderIdentity(BaseModel):
|
||||
"""
|
||||
The identity of the trigger provider
|
||||
"""
|
||||
|
||||
author: str = Field(..., description="The author of the trigger provider")
|
||||
name: str = Field(..., description="The name of the trigger provider")
|
||||
label: I18nObject = Field(..., description="The label of the trigger provider")
|
||||
description: I18nObject = Field(..., description="The description of the trigger provider")
|
||||
icon: Optional[str] = Field(default=None, description="The icon of the trigger provider")
|
||||
tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider")
|
||||
|
||||
|
||||
class TriggerIdentity(BaseModel):
|
||||
"""
|
||||
The identity of the trigger
|
||||
"""
|
||||
|
||||
author: str = Field(..., description="The author of the trigger")
|
||||
name: str = Field(..., description="The name of the trigger")
|
||||
label: I18nObject = Field(..., description="The label of the trigger")
|
||||
|
||||
|
||||
class TriggerDescription(BaseModel):
|
||||
"""
|
||||
The description of the trigger
|
||||
"""
|
||||
|
||||
human: I18nObject = Field(..., description="Human readable description")
|
||||
llm: I18nObject = Field(..., description="LLM readable description")
|
||||
|
||||
|
||||
class TriggerConfigurationExtraPython(BaseModel):
|
||||
"""Python configuration for trigger"""
|
||||
|
||||
source: str = Field(..., description="The source file path for the trigger implementation")
|
||||
|
||||
|
||||
class TriggerConfigurationExtra(BaseModel):
|
||||
"""
|
||||
The extra configuration for trigger
|
||||
"""
|
||||
|
||||
|
||||
class TriggerEntity(BaseModel):
|
||||
"""
|
||||
The configuration of a trigger
|
||||
"""
|
||||
|
||||
python: TriggerConfigurationExtraPython
|
||||
identity: TriggerIdentity = Field(..., description="The identity of the trigger")
|
||||
parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger")
|
||||
description: TriggerDescription = Field(..., description="The description of the trigger")
|
||||
extra: TriggerConfigurationExtra = Field(..., description="The extra configuration of the trigger")
|
||||
output_schema: Optional[Mapping[str, Any]] = Field(
|
||||
default=None, description="The output schema that this trigger produces"
|
||||
)
|
||||
|
||||
|
||||
class TriggerProviderConfigurationExtraPython(BaseModel):
|
||||
"""Python configuration for trigger provider"""
|
||||
|
||||
source: str = Field(..., description="The source file path for the trigger provider implementation")
|
||||
|
||||
|
||||
class TriggerProviderConfigurationExtra(BaseModel):
|
||||
"""
|
||||
The extra configuration for trigger provider
|
||||
"""
|
||||
|
||||
python: TriggerProviderConfigurationExtraPython
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
"""OAuth configuration schema"""
|
||||
|
||||
authorization_url: str = Field(..., description="OAuth authorization URL")
|
||||
token_url: str = Field(..., description="OAuth token URL")
|
||||
client_id: str = Field(..., description="OAuth client ID")
|
||||
client_secret: str = Field(..., description="OAuth client secret")
|
||||
redirect_uri: Optional[str] = Field(default=None, description="OAuth redirect URI")
|
||||
scope: Optional[str] = Field(default=None, description="OAuth scope")
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
|
||||
"""Provider configuration item"""
|
||||
|
||||
name: str = Field(..., description="Configuration field name")
|
||||
type: str = Field(..., description="Configuration field type")
|
||||
required: bool = Field(default=False, description="Whether this field is required")
|
||||
default: Any = Field(default=None, description="Default value")
|
||||
label: Optional[I18nObject] = Field(default=None, description="Field label")
|
||||
description: Optional[I18nObject] = Field(default=None, description="Field description")
|
||||
options: Optional[list[dict[str, Any]]] = Field(default=None, description="Options for select type")
|
||||
|
||||
|
||||
class TriggerProviderEntity(BaseModel):
|
||||
"""
|
||||
The configuration of a trigger provider
|
||||
"""
|
||||
|
||||
identity: TriggerProviderIdentity = Field(..., description="The identity of the trigger provider")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="The credentials schema of the trigger provider",
|
||||
)
|
||||
oauth_schema: Optional[OAuthSchema] = Field(
|
||||
default=None,
|
||||
description="The OAuth schema of the trigger provider if OAuth is supported",
|
||||
)
|
||||
subscription_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list,
|
||||
description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters",
|
||||
)
|
||||
triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider")
|
||||
extra: TriggerProviderConfigurationExtra = Field(..., description="The extra configuration of the trigger provider")
|
||||
|
||||
|
||||
class Subscription(BaseModel):
|
||||
"""
|
||||
Result of a successful trigger subscription operation.
|
||||
|
||||
Contains all information needed to manage the subscription lifecycle.
|
||||
"""
|
||||
|
||||
expire_at: int = Field(
|
||||
..., description="The timestamp when the subscription will expire, this for refresh the subscription"
|
||||
)
|
||||
|
||||
metadata: dict[str, Any] = Field(
|
||||
..., description="Metadata about the subscription in the external service, defined in subscription_schema"
|
||||
)
|
||||
|
||||
|
||||
class Unsubscription(BaseModel):
|
||||
"""
|
||||
Result of a trigger unsubscription operation.
|
||||
|
||||
Provides detailed information about the unsubscription attempt,
|
||||
including success status and error details if failed.
|
||||
"""
|
||||
|
||||
success: bool = Field(..., description="Whether the unsubscription was successful")
|
||||
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="Human-readable message about the operation result. "
|
||||
"Success message for successful operations, "
|
||||
"detailed error information for failures.",
|
||||
)
|
||||
|
||||
|
||||
# Export all entities
|
||||
__all__ = [
|
||||
"OAuthSchema",
|
||||
"ParameterAutoGenerate",
|
||||
"ParameterTemplate",
|
||||
"ProviderConfig",
|
||||
"Subscription",
|
||||
"TriggerConfigurationExtra",
|
||||
"TriggerConfigurationExtraPython",
|
||||
"TriggerDescription",
|
||||
"TriggerEntity",
|
||||
"TriggerEntity",
|
||||
"TriggerIdentity",
|
||||
"TriggerParameter",
|
||||
"TriggerParameterOption",
|
||||
"TriggerParameterType",
|
||||
"TriggerProviderConfigurationExtra",
|
||||
"TriggerProviderConfigurationExtraPython",
|
||||
"TriggerProviderEntity",
|
||||
"TriggerProviderIdentity",
|
||||
"Unsubscription",
|
||||
]
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
"""
|
||||
Trigger Provider Controller for managing trigger providers
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities import (
|
||||
ProviderConfig,
|
||||
Subscription,
|
||||
TriggerEntity,
|
||||
TriggerProviderEntity,
|
||||
TriggerProviderIdentity,
|
||||
Unsubscription,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderController:
|
||||
"""
|
||||
Controller for plugin trigger providers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: TriggerProviderEntity,
|
||||
plugin_id: str,
|
||||
plugin_unique_identifier: str,
|
||||
tenant_id: str,
|
||||
):
|
||||
"""
|
||||
Initialize plugin trigger provider controller
|
||||
|
||||
:param entity: Trigger provider entity
|
||||
:param plugin_id: Plugin ID
|
||||
:param plugin_unique_identifier: Plugin unique identifier
|
||||
:param tenant_id: Tenant ID
|
||||
"""
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def identity(self) -> TriggerProviderIdentity:
|
||||
"""Get provider identity"""
|
||||
return self.entity.identity
|
||||
|
||||
def get_triggers(self) -> list[TriggerEntity]:
|
||||
"""
|
||||
Get all triggers for this provider
|
||||
|
||||
:return: List of trigger entities
|
||||
"""
|
||||
return self.entity.triggers
|
||||
|
||||
def get_trigger(self, trigger_name: str) -> Optional[TriggerEntity]:
|
||||
"""
|
||||
Get a specific trigger by name
|
||||
|
||||
:param trigger_name: Trigger name
|
||||
:return: Trigger entity or None
|
||||
"""
|
||||
for trigger in self.entity.triggers:
|
||||
if trigger.identity.name == trigger_name:
|
||||
return trigger
|
||||
return None
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get credentials schema for this provider
|
||||
|
||||
:return: List of provider config schemas
|
||||
"""
|
||||
return self.entity.credentials_schema
|
||||
|
||||
def get_subscription_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get subscription schema for this provider
|
||||
|
||||
:return: List of subscription config schemas
|
||||
"""
|
||||
return self.entity.subscription_schema
|
||||
|
||||
def validate_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate credentials against schema
|
||||
|
||||
:param credentials: Credentials to validate
|
||||
:raises ValueError: If credentials are invalid
|
||||
"""
|
||||
for config in self.entity.credentials_schema:
|
||||
if config.required and config.name not in credentials:
|
||||
raise ValueError(f"Missing required credential field: {config.name}")
|
||||
|
||||
def get_supported_credential_types(self) -> list[CredentialType]:
|
||||
"""
|
||||
Get supported credential types for this provider.
|
||||
|
||||
:return: List of supported credential types
|
||||
"""
|
||||
types = []
|
||||
if self.entity.oauth_schema:
|
||||
types.append(CredentialType.OAUTH2)
|
||||
if self.entity.credentials_schema:
|
||||
types.append(CredentialType.API_KEY)
|
||||
return types
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get credentials schema by credential type
|
||||
|
||||
:param credential_type: The type of credential (oauth or api_key)
|
||||
:return: List of provider config schemas
|
||||
"""
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY.value:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get OAuth client schema for this provider
|
||||
|
||||
:return: List of OAuth client config schemas
|
||||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""Check if this provider needs credentials"""
|
||||
return len(self.get_supported_credential_types()) > 0
|
||||
|
||||
def execute_trigger(self, trigger_name: str, parameters: dict, credentials: dict) -> dict:
|
||||
"""
|
||||
Execute a trigger through plugin runtime
|
||||
|
||||
:param trigger_name: Trigger name
|
||||
:param parameters: Trigger parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Execution result
|
||||
"""
|
||||
logger.info("Executing trigger %s for plugin %s", trigger_name, self.plugin_id)
|
||||
return {
|
||||
"success": True,
|
||||
"trigger": trigger_name,
|
||||
"plugin": self.plugin_id,
|
||||
"result": "Trigger executed successfully",
|
||||
}
|
||||
|
||||
def subscribe_trigger(self, trigger_name: str, subscription_params: dict, credentials: dict) -> Subscription:
|
||||
"""
|
||||
Subscribe to a trigger through plugin runtime
|
||||
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_params: Subscription parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Subscription result
|
||||
"""
|
||||
logger.info("Subscribing to trigger %s for plugin %s", trigger_name, self.plugin_id)
|
||||
return Subscription(
|
||||
expire_at=int(time.time()) + 86400, # 24 hours from now
|
||||
metadata={
|
||||
"subscription_id": f"{self.plugin_id}_{trigger_name}_{time.time()}",
|
||||
"webhook_url": f"/triggers/webhook/{self.plugin_id}/{trigger_name}",
|
||||
**subscription_params,
|
||||
},
|
||||
)
|
||||
|
||||
def unsubscribe_trigger(self, trigger_name: str, subscription_metadata: dict, credentials: dict) -> Unsubscription:
|
||||
"""
|
||||
Unsubscribe from a trigger through plugin runtime
|
||||
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_metadata: Subscription metadata
|
||||
:param credentials: Provider credentials
|
||||
:return: Unsubscription result
|
||||
"""
|
||||
logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id)
|
||||
return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}")
|
||||
|
||||
def handle_webhook(self, webhook_path: str, request_data: dict, credentials: dict) -> dict:
|
||||
"""
|
||||
Handle incoming webhook through plugin runtime
|
||||
|
||||
:param webhook_path: Webhook path
|
||||
:param request_data: Request data
|
||||
:param credentials: Provider credentials
|
||||
:return: Webhook handling result
|
||||
"""
|
||||
logger.info("Handling webhook for path %s for plugin %s", webhook_path, self.plugin_id)
|
||||
return {"success": True, "path": webhook_path, "plugin": self.plugin_id, "data_received": request_data}
|
||||
|
||||
|
||||
__all__ = ["TriggerProviderController"]
|
||||
|
|
@ -0,0 +1,360 @@
|
|||
"""
|
||||
Trigger Manager for loading and managing trigger providers and triggers
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.trigger.entities import (
|
||||
ProviderConfig,
|
||||
TriggerEntity,
|
||||
)
|
||||
from core.trigger.plugin_trigger import PluginTriggerController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerManager:
|
||||
"""
|
||||
Manager for trigger providers and triggers
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def list_plugin_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
|
||||
"""
|
||||
List all plugin trigger providers for a tenant
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:return: List of trigger provider controllers
|
||||
"""
|
||||
manager = PluginTriggerController()
|
||||
provider_entities = manager.fetch_trigger_providers(tenant_id)
|
||||
|
||||
controllers = []
|
||||
for provider in provider_entities:
|
||||
try:
|
||||
controller = PluginTriggerProviderController(
|
||||
entity=provider.declaration,
|
||||
plugin_id=provider.plugin_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
controllers.append(controller)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to load trigger provider {provider.plugin_id}")
|
||||
continue
|
||||
|
||||
return controllers
|
||||
|
||||
@classmethod
|
||||
def get_plugin_trigger_provider(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str
|
||||
) -> Optional[PluginTriggerProviderController]:
|
||||
"""
|
||||
Get a specific plugin trigger provider
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:return: Trigger provider controller or None
|
||||
"""
|
||||
manager = PluginTriggerManager()
|
||||
provider = manager.fetch_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
try:
|
||||
return PluginTriggerProviderController(
|
||||
entity=provider.declaration,
|
||||
plugin_id=provider.plugin_id,
|
||||
plugin_unique_identifier=provider.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to load trigger provider")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]:
|
||||
"""
|
||||
List all trigger providers (plugin and builtin)
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:return: List of all trigger provider controllers
|
||||
"""
|
||||
providers = []
|
||||
|
||||
# Get plugin providers
|
||||
plugin_providers = cls.list_plugin_trigger_providers(tenant_id)
|
||||
providers.extend(plugin_providers)
|
||||
|
||||
# TODO: Add builtin providers when implemented
|
||||
# builtin_providers = cls.list_builtin_trigger_providers(tenant_id)
|
||||
# providers.extend(builtin_providers)
|
||||
|
||||
return providers
|
||||
|
||||
@classmethod
|
||||
def list_triggers_by_provider(cls, tenant_id: str, plugin_id: str, provider_name: str) -> list[TriggerEntity]:
|
||||
"""
|
||||
List all triggers for a specific provider
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:return: List of trigger entities
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return []
|
||||
|
||||
return provider.get_triggers()
|
||||
|
||||
@classmethod
|
||||
def get_trigger(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str
|
||||
) -> Optional[TriggerEntity]:
|
||||
"""
|
||||
Get a specific trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param trigger_name: Trigger name
|
||||
:return: Trigger entity or None
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
return provider.get_trigger(trigger_name)
|
||||
|
||||
@classmethod
|
||||
def validate_trigger_credentials(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str, credentials: dict
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate trigger provider credentials
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param credentials: Credentials to validate
|
||||
:return: Tuple of (is_valid, error_message)
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return False, "Provider not found"
|
||||
|
||||
try:
|
||||
provider.validate_credentials(credentials)
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
@classmethod
|
||||
def execute_trigger(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str, parameters: dict, credentials: dict
|
||||
) -> dict:
|
||||
"""
|
||||
Execute a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param trigger_name: Trigger name
|
||||
:param parameters: Trigger parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Trigger execution result
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
|
||||
|
||||
trigger = provider.get_trigger(trigger_name)
|
||||
if not trigger:
|
||||
raise ValueError(f"Trigger {trigger_name} not found in provider {provider_name}")
|
||||
|
||||
return provider.execute_trigger(trigger_name, parameters, credentials)
|
||||
|
||||
@classmethod
|
||||
def subscribe_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
plugin_id: str,
|
||||
provider_name: str,
|
||||
trigger_name: str,
|
||||
subscription_params: dict,
|
||||
credentials: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Subscribe to a trigger (e.g., register webhook)
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_params: Subscription parameters
|
||||
:param credentials: Provider credentials
|
||||
:return: Subscription result
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
|
||||
|
||||
return provider.subscribe_trigger(trigger_name, subscription_params, credentials)
|
||||
|
||||
@classmethod
|
||||
def unsubscribe_trigger(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
plugin_id: str,
|
||||
provider_name: str,
|
||||
trigger_name: str,
|
||||
subscription_metadata: dict,
|
||||
credentials: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Unsubscribe from a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param trigger_name: Trigger name
|
||||
:param subscription_metadata: Subscription metadata from subscribe operation
|
||||
:param credentials: Provider credentials
|
||||
:return: Unsubscription result
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
|
||||
|
||||
return provider.unsubscribe_trigger(trigger_name, subscription_metadata, credentials)
|
||||
|
||||
@classmethod
|
||||
def handle_webhook(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
plugin_id: str,
|
||||
provider_name: str,
|
||||
webhook_path: str,
|
||||
request_data: dict,
|
||||
credentials: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Handle incoming webhook for a trigger
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:param webhook_path: Webhook path
|
||||
:param request_data: Webhook request data
|
||||
:param credentials: Provider credentials
|
||||
:return: Webhook handling result
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {plugin_id}/{provider_name} not found")
|
||||
|
||||
return provider.handle_webhook(webhook_path, request_data, credentials)
|
||||
|
||||
@classmethod
|
||||
def get_provider_credentials_schema(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str
|
||||
) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get provider credentials schema
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:return: List of provider config schemas
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return []
|
||||
|
||||
return provider.get_credentials_schema()
|
||||
|
||||
@classmethod
|
||||
def get_provider_subscription_schema(
|
||||
cls, tenant_id: str, plugin_id: str, provider_name: str
|
||||
) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get provider subscription schema
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:return: List of subscription config schemas
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return []
|
||||
|
||||
return provider.get_subscription_schema()
|
||||
|
||||
@classmethod
|
||||
def get_provider_info(cls, tenant_id: str, plugin_id: str, provider_name: str) -> Optional[dict]:
|
||||
"""
|
||||
Get provider information
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param plugin_id: Plugin ID
|
||||
:param provider_name: Provider name
|
||||
:return: Provider info dict or None
|
||||
"""
|
||||
provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name)
|
||||
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
return {
|
||||
"plugin_id": plugin_id,
|
||||
"provider_name": provider_name,
|
||||
"identity": provider.entity.identity.model_dump() if provider.entity.identity else {},
|
||||
"credentials_schema": [c.model_dump() for c in provider.entity.credentials_schema],
|
||||
"subscription_schema": [s.model_dump() for s in provider.entity.subscription_schema],
|
||||
"oauth_enabled": provider.entity.oauth_schema is not None,
|
||||
"trigger_count": len(provider.entity.triggers),
|
||||
"triggers": [t.identity.model_dump() for t in provider.entity.triggers],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_providers_for_workflow(cls, tenant_id: str) -> list[dict]:
|
||||
"""
|
||||
List trigger providers suitable for workflow usage
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:return: List of provider info dicts
|
||||
"""
|
||||
providers = cls.list_all_trigger_providers(tenant_id)
|
||||
|
||||
result = []
|
||||
for provider in providers:
|
||||
info = {
|
||||
"plugin_id": provider.plugin_id,
|
||||
"provider_name": provider.entity.identity.name,
|
||||
"label": provider.entity.identity.label.model_dump(),
|
||||
"description": provider.entity.identity.description.model_dump(),
|
||||
"icon": provider.entity.identity.icon,
|
||||
"trigger_count": len(provider.entity.triggers),
|
||||
}
|
||||
result.append(info)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Export
|
||||
__all__ = ["TriggerManager"]
|
||||
|
|
@ -20,6 +20,7 @@ def init_app(app: DifyApp):
|
|||
reset_encrypt_key_pair,
|
||||
reset_password,
|
||||
setup_system_tool_oauth_client,
|
||||
setup_system_trigger_oauth_client,
|
||||
upgrade_db,
|
||||
vdb_migrate,
|
||||
)
|
||||
|
|
@ -43,6 +44,7 @@ def init_app(app: DifyApp):
|
|||
clear_orphaned_file_records,
|
||||
remove_orphaned_files_on_storage,
|
||||
setup_system_tool_oauth_client,
|
||||
setup_system_trigger_oauth_client,
|
||||
cleanup_orphaned_draft_variables,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, Index, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from models.base import Base
|
||||
from models.types import StringUUID
|
||||
|
||||
|
||||
class TriggerProvider(Base):
|
||||
"""
|
||||
Trigger provider model for managing credentials
|
||||
Supports multiple credential instances per provider
|
||||
"""
|
||||
|
||||
__tablename__ = "trigger_providers"
|
||||
__table_args__ = (Index("idx_trigger_providers_tenant_provider", "tenant_id", "provider_id"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_id: Mapped[str] = mapped_column(
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||
encrypted_credentials: Mapped[str] = mapped_column(Text, nullable=False, comment="Encrypted credentials JSON")
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Credential instance name")
|
||||
expires_at: Mapped[int] = mapped_column(
|
||||
Integer, default=-1, comment="OAuth token expiration timestamp, -1 for never"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
"""Get credentials as dict (still encrypted)"""
|
||||
try:
|
||||
return json.loads(self.encrypted_credentials) if self.encrypted_credentials else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def credentials_str(self) -> str:
|
||||
"""Get credentials as string"""
|
||||
return self.encrypted_credentials or "{}"
|
||||
|
||||
def is_oauth_expired(self) -> bool:
|
||||
"""Check if OAuth token is expired"""
|
||||
if self.credential_type != CredentialType.OAUTH2.value:
|
||||
return False
|
||||
if self.expires_at == -1:
|
||||
return False
|
||||
# Check if token expires in next 60 seconds
|
||||
return (self.expires_at - 60) < int(datetime.now(UTC).timestamp())
|
||||
|
||||
|
||||
# system level trigger oauth client params
|
||||
class TriggerOAuthSystemClient(Base):
|
||||
__tablename__ = "trigger_oauth_system_clients"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trigger_oauth_system_client_pkey"),
|
||||
sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
# oauth params of the trigger provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
|
||||
# tenant level trigger oauth client params (client_id, client_secret, etc.)
|
||||
class TriggerOAuthTenantClient(Base):
|
||||
__tablename__ = "trigger_oauth_tenant_clients"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trigger_oauth_tenant_client_pkey"),
|
||||
sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
# tenant id
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
|
||||
# oauth params of the trigger provider
|
||||
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> dict:
|
||||
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
|
@ -13,6 +13,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
|||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import (
|
||||
|
|
@ -21,7 +22,6 @@ from core.tools.entities.api_entities import (
|
|||
ToolProviderCredentialApiEntity,
|
||||
ToolProviderCredentialInfoApiEntity,
|
||||
)
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
|
|
@ -39,7 +39,6 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class BuiltinToolManageService:
|
||||
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
||||
__DEFAULT_EXPIRES_AT__ = 2147483647
|
||||
|
||||
@staticmethod
|
||||
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
|
||||
|
|
@ -278,9 +277,7 @@ class BuiltinToolManageService:
|
|||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||
credential_type=api_type.value,
|
||||
name=name,
|
||||
expires_at=expires_at
|
||||
if expires_at is not None
|
||||
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
|
||||
expires_at=expires_at if expires_at is not None else -1,
|
||||
)
|
||||
|
||||
session.add(db_provider)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from yarl import URL
|
|||
from configs import dify_config
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
|
@ -16,7 +17,6 @@ from core.tools.entities.common_entities import I18nObject
|
|||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
CredentialType,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,588 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.service import PluginService
|
||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerProviderService:
|
||||
"""Service for managing trigger providers and credentials"""
|
||||
|
||||
__MAX_TRIGGER_PROVIDER_COUNT__ = 100
|
||||
|
||||
@classmethod
|
||||
def list_trigger_providers(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerProvider]:
|
||||
"""List all trigger providers for the current tenant"""
|
||||
# TODO fetch trigger plugin controller
|
||||
|
||||
# TODO fetch all trigger plugin credentials
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
return session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).all()
|
||||
|
||||
@classmethod
|
||||
def add_trigger_provider(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
credential_type: CredentialType,
|
||||
credentials: dict,
|
||||
name: Optional[str] = None,
|
||||
expires_at: int = -1,
|
||||
) -> dict:
|
||||
"""
|
||||
Add a new trigger provider with credentials.
|
||||
Supports multiple credential instances per provider.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider identifier (e.g., "plugin_id/provider_name")
|
||||
:param credential_type: Type of credential (oauth or api_key)
|
||||
:param credentials: Credential data to encrypt and store
|
||||
:param name: Optional name for this credential instance
|
||||
:param expires_at: OAuth token expiration timestamp
|
||||
:return: Success response
|
||||
"""
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Use distributed lock to prevent race conditions
|
||||
lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}"
|
||||
with redis_client.lock(lock_key, timeout=20):
|
||||
# Check provider count limit
|
||||
provider_count = (
|
||||
session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).count()
|
||||
)
|
||||
|
||||
if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__:
|
||||
raise ValueError(
|
||||
f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) "
|
||||
f"reached for {provider_id}"
|
||||
)
|
||||
|
||||
# Generate name if not provided
|
||||
if not name:
|
||||
name = cls._generate_provider_name(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists for this provider")
|
||||
|
||||
# Create encrypter for credentials
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[], # We'll define schema later in TriggerProvider classes
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# Create provider record
|
||||
db_provider = TriggerProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type.value,
|
||||
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
|
||||
return {"result": "success", "id": str(db_provider.id)}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to add trigger provider")
|
||||
raise ValueError(str(e))
|
||||
|
||||
@classmethod
|
||||
def update_trigger_provider(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credential_id: str,
|
||||
credentials: Optional[dict] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Update an existing trigger provider's credentials or name.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param credential_id: Credential instance ID
|
||||
:param credentials: New credentials (optional)
|
||||
:param name: New name (optional)
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Get provider
|
||||
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider credential {credential_id} not found")
|
||||
|
||||
try:
|
||||
# Update credentials if provided
|
||||
if credentials:
|
||||
encrypter, cache = cls._create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
provider=db_provider,
|
||||
)
|
||||
|
||||
# Handle hidden values
|
||||
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
new_credentials = {
|
||||
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
|
||||
for key, value in credentials.items()
|
||||
}
|
||||
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
|
||||
cache.delete()
|
||||
|
||||
# Update name if provided
|
||||
if name and name != db_provider.name:
|
||||
# Check if name already exists
|
||||
existing = (
|
||||
session.query(TriggerProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name)
|
||||
.filter(TriggerProvider.id != credential_id)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Credential name '{name}' already exists")
|
||||
|
||||
db_provider.name = name
|
||||
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
@classmethod
|
||||
def delete_trigger_provider(cls, tenant_id: str, credential_id: str) -> dict:
|
||||
"""
|
||||
Delete a trigger provider credential.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param credential_id: Credential instance ID
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider credential {credential_id} not found")
|
||||
|
||||
# Delete provider
|
||||
session.delete(db_provider)
|
||||
session.commit()
|
||||
|
||||
# Clear cache
|
||||
_, cache = cls._create_provider_encrypter(tenant_id, db_provider)
|
||||
cache.delete()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def refresh_oauth_token(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credential_id: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Refresh OAuth token for a trigger provider.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param credential_id: Credential instance ID
|
||||
:return: New token info
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
||||
|
||||
if not db_provider:
|
||||
raise ValueError(f"Trigger provider credential {credential_id} not found")
|
||||
|
||||
if db_provider.credential_type != CredentialType.OAUTH2.value:
|
||||
raise ValueError("Only OAuth credentials can be refreshed")
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(db_provider.provider_id)
|
||||
|
||||
# Create encrypter
|
||||
encrypter, cache = cls._create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
provider=db_provider,
|
||||
)
|
||||
|
||||
# Decrypt current credentials
|
||||
current_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
|
||||
# Get OAuth client configuration
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{db_provider.provider_id}/trigger/callback"
|
||||
)
|
||||
system_credentials = cls.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
# Refresh token
|
||||
oauth_handler = OAuthHandler()
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=db_provider.user_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=system_credentials or {},
|
||||
credentials=current_credentials,
|
||||
)
|
||||
|
||||
# Update credentials
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(dict(refreshed_credentials.credentials)))
|
||||
db_provider.expires_at = refreshed_credentials.expires_at
|
||||
session.commit()
|
||||
|
||||
# Clear cache
|
||||
cache.delete()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"expires_at": refreshed_credentials.expires_at,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]:
|
||||
"""
|
||||
Get OAuth client configuration for a provider.
|
||||
First tries tenant-level OAuth, then falls back to system OAuth.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider identifier
|
||||
:return: OAuth client configuration or None
|
||||
"""
|
||||
# Get trigger provider controller to access schema
|
||||
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
|
||||
|
||||
# Create encrypter for OAuth client params
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
# First check for tenant-specific OAuth client
|
||||
tenant_client: TriggerOAuthTenantClient | None = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
oauth_params: Mapping[str, Any] | None = None
|
||||
if tenant_client:
|
||||
oauth_params = encrypter.decrypt(tenant_client.oauth_params)
|
||||
return oauth_params
|
||||
|
||||
# Only verified plugins can use system OAuth client
|
||||
is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id)
|
||||
if not is_verified:
|
||||
return oauth_params
|
||||
|
||||
# Check for system-level OAuth client
|
||||
system_client: TriggerOAuthSystemClient | None = (
|
||||
session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if system_client:
|
||||
try:
|
||||
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error decrypting system oauth params: {e}")
|
||||
|
||||
return oauth_params
|
||||
|
||||
@classmethod
|
||||
def save_custom_oauth_client_params(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
client_params: Optional[dict] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Save or update custom OAuth client parameters for a trigger provider.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider identifier
|
||||
:param client_params: OAuth client parameters (client_id, client_secret, etc.)
|
||||
:param enabled: Enable/disable the custom OAuth client
|
||||
:return: Success response
|
||||
"""
|
||||
if client_params is None and enabled is None:
|
||||
return {"result": "success"}
|
||||
|
||||
# Get provider controller to access schema
|
||||
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find existing custom client params
|
||||
custom_client = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Create new record if doesn't exist
|
||||
if custom_client is None:
|
||||
custom_client = TriggerOAuthTenantClient(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
)
|
||||
session.add(custom_client)
|
||||
|
||||
# Update client params if provided
|
||||
if client_params is not None:
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# Handle hidden values
|
||||
original_params = encrypter.decrypt(custom_client.oauth_params)
|
||||
new_params: dict = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
|
||||
|
||||
# Update enabled status if provided
|
||||
if enabled is not None:
|
||||
custom_client.enabled = enabled
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
|
||||
"""
|
||||
Get custom OAuth client parameters for a trigger provider.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider identifier
|
||||
:return: Masked OAuth client parameters
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
custom_client = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if custom_client is None:
|
||||
return {}
|
||||
|
||||
# Get provider controller to access schema
|
||||
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
|
||||
|
||||
# Create encrypter to decrypt and mask values
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params))
|
||||
|
||||
@classmethod
|
||||
def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict:
|
||||
"""
|
||||
Delete custom OAuth client parameters for a trigger provider.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider identifier
|
||||
:return: Success response
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
session.query(TriggerOAuthTenantClient).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool:
|
||||
"""
|
||||
Check if custom OAuth client is enabled for a trigger provider.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider_id: Provider identifier
|
||||
:return: True if enabled, False otherwise
|
||||
"""
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
custom_client = (
|
||||
session.query(TriggerOAuthTenantClient)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
enabled=True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return custom_client is not None
|
||||
|
||||
@classmethod
|
||||
def create_oauth_proxy_context(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
) -> str:
|
||||
"""
|
||||
Create OAuth proxy context for authorization flow.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param user_id: User ID
|
||||
:param provider: Provider identifier
|
||||
:return: Context ID for OAuth flow
|
||||
"""
|
||||
return OAuthProxyService.create_proxy_context(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
provider=provider_id.provider_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _create_provider_encrypter(
|
||||
cls, tenant_id: str, provider: TriggerProvider
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
"""
|
||||
Create encrypter and cache for trigger provider credentials
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider: TriggerProvider instance
|
||||
:return: Tuple of encrypter and cache
|
||||
"""
|
||||
from core.helper.provider_cache import TriggerProviderCredentialCache
|
||||
|
||||
# Parse provider ID
|
||||
provider_id = TriggerProviderID(provider.provider_id)
|
||||
|
||||
# Get trigger provider controller to access schema
|
||||
provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id)
|
||||
|
||||
# Create encrypter with appropriate schema based on credential type
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(provider.credential_type)
|
||||
],
|
||||
cache=TriggerProviderCredentialCache(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider.provider_id,
|
||||
credential_id=provider.id,
|
||||
),
|
||||
)
|
||||
|
||||
return encrypter, cache
|
||||
|
||||
@classmethod
|
||||
def _generate_provider_name(
|
||||
cls,
|
||||
session: Session,
|
||||
tenant_id: str,
|
||||
provider_id: TriggerProviderID,
|
||||
credential_type: CredentialType,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a unique name for a provider credential instance.
|
||||
|
||||
:param session: Database session
|
||||
:param tenant_id: Tenant ID
|
||||
:param provider: Provider identifier
|
||||
:param credential_type: Credential type
|
||||
:return: Generated name
|
||||
"""
|
||||
try:
|
||||
db_providers = (
|
||||
session.query(TriggerProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
.order_by(desc(TriggerProvider.created_at))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get base name
|
||||
base_name = credential_type.get_name()
|
||||
|
||||
# Find existing numbered names
|
||||
pattern = rf"^{re.escape(base_name)}\s+(\d+)$"
|
||||
numbers = []
|
||||
|
||||
for db_provider in db_providers:
|
||||
if db_provider.name:
|
||||
match = re.match(pattern, db_provider.name.strip())
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
# Generate next number
|
||||
if not numbers:
|
||||
return f"{base_name} 1"
|
||||
|
||||
max_number = max(numbers)
|
||||
return f"{base_name} {max_number + 1}"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Error generating provider name")
|
||||
return f"{credential_type.get_name()} 1"
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import Request, Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerService:
|
||||
__MAX_REQUEST_LOG_COUNT__ = 10
|
||||
|
||||
@classmethod
|
||||
def process_webhook(cls, webhook_id: str, request: Request) -> Response:
|
||||
"""Extract and process data from incoming webhook request."""
|
||||
# TODO redis slidingwindow log, save the recent request log in redis, rollover the log when the window is full
|
||||
|
||||
# TODO find the trigger subscription
|
||||
|
||||
# TODO fetch the trigger controller
|
||||
|
||||
# TODO dispatch by the trigger controller
|
||||
|
||||
# TODO using the dispatch result(events) to invoke the trigger events
|
||||
|
|
@ -55,9 +55,7 @@ class BaseQueueDispatcher(ABC):
|
|||
True if quota available, False otherwise
|
||||
"""
|
||||
# Check without consuming
|
||||
remaining = self.rate_limiter.get_remaining_quota(
|
||||
tenant_id=tenant_id, max_daily_limit=self.get_daily_limit()
|
||||
)
|
||||
remaining = self.rate_limiter.get_remaining_quota(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit())
|
||||
return remaining > 0
|
||||
|
||||
def consume_quota(self, tenant_id: str) -> bool:
|
||||
|
|
@ -70,9 +68,7 @@ class BaseQueueDispatcher(ABC):
|
|||
Returns:
|
||||
True if quota consumed successfully, False if limit reached
|
||||
"""
|
||||
return self.rate_limiter.check_and_consume(
|
||||
tenant_id=tenant_id, max_daily_limit=self.get_daily_limit()
|
||||
)
|
||||
return self.rate_limiter.check_and_consume(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit())
|
||||
|
||||
|
||||
class ProfessionalQueueDispatcher(BaseQueueDispatcher):
|
||||
|
|
|
|||
|
|
@ -74,10 +74,10 @@ class TenantDailyRateLimiter:
|
|||
Number of seconds until UTC midnight
|
||||
"""
|
||||
utc_now = datetime.utcnow()
|
||||
|
||||
|
||||
# Get next midnight in UTC
|
||||
next_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min)
|
||||
|
||||
|
||||
return int((next_midnight - utc_now).total_seconds())
|
||||
|
||||
def check_and_consume(self, tenant_id: str, max_daily_limit: int) -> bool:
|
||||
|
|
@ -174,9 +174,9 @@ class TenantDailyRateLimiter:
|
|||
"""
|
||||
tz = pytz.timezone(timezone_str)
|
||||
utc_now = datetime.utcnow()
|
||||
|
||||
|
||||
# Get next midnight in UTC, then convert to tenant's timezone
|
||||
next_utc_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min)
|
||||
next_utc_midnight = pytz.UTC.localize(next_utc_midnight)
|
||||
|
||||
|
||||
return next_utc_midnight.astimezone(tz)
|
||||
|
|
|
|||
Loading…
Reference in New Issue