diff --git a/CLAUDE.md b/CLAUDE.md index 1b649ca9a0..6d0cb6e1ad 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -59,6 +59,7 @@ pnpm test # Run Jest tests - Use type hints for all functions and class attributes - No `Any` types unless absolutely necessary - Implement special methods (`__repr__`, `__str__`) appropriately +- **Logging**: Never use `str(e)` in `logger.exception()` calls. Use `logger.exception("message", exc_info=e)` instead ### TypeScript/JavaScript diff --git a/api/commands.py b/api/commands.py index 89fef39d25..4fa566f704 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1207,6 +1207,55 @@ def setup_system_tool_oauth_client(provider, client_params): click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) +@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_system_trigger_oauth_client(provider, client_params): + """ + Setup system trigger oauth client + """ + from core.plugin.entities.plugin import TriggerProviderID + from models.trigger import TriggerOAuthSystemClient + + provider_id = TriggerProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + + click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) + click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) + click.echo(click.style("Client params encrypted successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + deleted_count = ( + db.session.query(TriggerOAuthSystemClient) + .filter_by( + provider=provider_name, + plugin_id=plugin_id, + ) + .delete() + ) + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + oauth_client = TriggerOAuthSystemClient( + provider=provider_name, + plugin_id=plugin_id, + encrypted_oauth_params=oauth_client_params, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) + + def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: """ Find draft variables that reference non-existent apps. diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d9f2e45ddf..2cdedf0e9d 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -22,8 +22,8 @@ from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID +from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.entities.tool_entities import CredentialType from libs.helper import StrLen, alphanumeric, uuid_value from libs.login import login_required from services.plugin.oauth_service import OAuthProxyService diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py new file mode 100644 index 0000000000..176f75ed29 --- /dev/null +++ b/api/controllers/console/workspace/trigger_providers.py @@ -0,0 +1,323 @@ +import logging + +from flask_restx import Resource, reqparse +from werkzeug.exceptions import BadRequest, Forbidden + +from controllers.console import api +from controllers.console.wraps import account_initialization_required, setup_required +from core.plugin.entities.plugin import TriggerProviderID +from core.plugin.entities.plugin_daemon import CredentialType +from libs.login import current_user, login_required +from models.account import Account +from services.trigger.trigger_provider_service import TriggerProviderService + +logger = logging.getLogger(__name__) + + +class TriggerProviderListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + """List all trigger providers for the current tenant""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + try: + return TriggerProviderService.list_trigger_providers( + tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider) + ) + except Exception as e: + logger.exception("Error listing trigger providers", exc_info=e) + raise + + +class TriggerProviderCredentialsAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + """Add a new credential instance for a trigger provider""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credential_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("expires_at", type=int, required=False, nullable=True, location="json", default=-1) + args = parser.parse_args() + + try: + # Parse credential type + try: + credential_type = CredentialType(args["credential_type"]) + except ValueError: + raise BadRequest(f"Invalid credential_type. Must be one of: {[t.value for t in CredentialType]}") + + result = TriggerProviderService.add_trigger_provider( + tenant_id=user.current_tenant_id, + user_id=user.id, + provider_id=TriggerProviderID(provider), + credential_type=credential_type, + credentials=args["credentials"], + name=args.get("name"), + expires_at=args.get("expires_at", -1), + ) + + return result + + except ValueError as e: + raise BadRequest(str(e)) + except Exception as e: + logger.exception("Error adding provider credential", exc_info=e) + raise + + +class TriggerProviderCredentialsUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, credential_id): + """Update an existing credential instance""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + args = parser.parse_args() + + try: + result = TriggerProviderService.update_trigger_provider( + tenant_id=user.current_tenant_id, + credential_id=credential_id, + credentials=args.get("credentials"), + name=args.get("name"), + ) + + return result + + except ValueError as e: + raise BadRequest(str(e)) + except Exception as e: + logger.exception("Error updating provider credential", exc_info=e) + raise + + +class TriggerProviderCredentialsDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, credential_id): + """Delete a credential instance""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + try: + result = TriggerProviderService.delete_trigger_provider( + tenant_id=user.current_tenant_id, + credential_id=credential_id, + ) + return result + + except ValueError as e: + raise BadRequest(str(e)) + except Exception as e: + logger.exception("Error deleting provider credential", exc_info=e) + raise + + +class TriggerProviderOAuthAuthorizeApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + """Initiate OAuth authorization flow for a provider""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + try: + context_id = TriggerProviderService.create_oauth_proxy_context( + tenant_id=user.current_tenant_id, + user_id=user.id, + provider_id=TriggerProviderID(provider), + ) + + # TODO: Build OAuth authorization URL + # This will be implemented when we have provider-specific OAuth configs + + return { + "context_id": context_id, + "authorization_url": f"/oauth/authorize?context={context_id}", + } + + except Exception as e: + logger.exception("Error initiating OAuth flow", exc_info=e) + raise + + +class TriggerProviderOAuthRefreshTokenApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, credential_id): + """Refresh OAuth token for a trigger provider credential""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + try: + result = TriggerProviderService.refresh_oauth_token( + tenant_id=user.current_tenant_id, + credential_id=credential_id, + ) + return result + + except ValueError as e: + raise BadRequest(str(e)) + except Exception as e: + logger.exception("Error refreshing OAuth token", exc_info=e) + raise + + +class TriggerProviderOAuthClientManageApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + """Get OAuth client configuration for a provider""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + try: + provider_id = TriggerProviderID(provider) + + # Get custom OAuth client params if exists + custom_params = TriggerProviderService.get_custom_oauth_client_params( + tenant_id=user.current_tenant_id, + provider_id=provider_id, + ) + + # Check if custom client is enabled + is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled( + tenant_id=user.current_tenant_id, + provider_id=provider_id, + ) + + # Check if there's a system OAuth client + system_client = TriggerProviderService.get_oauth_client( + tenant_id=user.current_tenant_id, + provider_id=provider_id, + ) + + return { + "configured": bool(custom_params or system_client), + "custom_configured": bool(custom_params), + "custom_enabled": is_custom_enabled, + "params": custom_params if custom_params else {}, + } + + except Exception as e: + logger.exception("Error getting OAuth client", exc_info=e) + raise + + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + """Configure custom OAuth client for a provider""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") + parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json") + args = parser.parse_args() + + try: + provider_id = TriggerProviderID(provider) + + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id=user.current_tenant_id, + provider_id=provider_id, + client_params=args.get("client_params"), + enabled=args.get("enabled"), + ) + + return result + + except ValueError as e: + raise BadRequest(str(e)) + except Exception as e: + logger.exception("Error configuring OAuth client", exc_info=e) + raise + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider): + """Remove custom OAuth client configuration""" + user = current_user + assert isinstance(user, Account) + assert user.current_tenant_id is not None + if not user.is_admin_or_owner: + raise Forbidden() + + try: + provider_id = TriggerProviderID(provider) + + result = TriggerProviderService.delete_custom_oauth_client_params( + tenant_id=user.current_tenant_id, + provider_id=provider_id, + ) + + return result + + except ValueError as e: + raise BadRequest(str(e)) + except Exception as e: + logger.exception("Error removing OAuth client", exc_info=e) + raise + + +# Trigger provider endpoints +api.add_resource(TriggerProviderListApi, "/workspaces/current/trigger-provider//list") +api.add_resource(TriggerProviderCredentialsAddApi, "/workspaces/current/trigger-provider//add") +api.add_resource( + TriggerProviderCredentialsUpdateApi, "/workspaces/current/trigger-provider/credentials//update" +) +api.add_resource( + TriggerProviderCredentialsDeleteApi, "/workspaces/current/trigger-provider/credentials//delete" +) + +api.add_resource( + TriggerProviderOAuthAuthorizeApi, "/workspaces/current/trigger-provider//oauth/authorize" +) +api.add_resource( + TriggerProviderOAuthRefreshTokenApi, + "/workspaces/current/trigger-provider/credentials//oauth/refresh", +) +api.add_resource( + TriggerProviderOAuthClientManageApi, "/workspaces/current/trigger-provider//oauth/client" +) diff --git a/api/controllers/trigger/__init__.py b/api/controllers/trigger/__init__.py index 9132c0179e..972f28649c 100644 --- a/api/controllers/trigger/__init__.py +++ b/api/controllers/trigger/__init__.py @@ -4,4 +4,4 @@ from flask import Blueprint bp = Blueprint("trigger", __name__, url_prefix="/triggers") # Import routes after blueprint creation to avoid circular imports -from . import webhook +from . import trigger, webhook diff --git a/api/controllers/trigger/trigger.py b/api/controllers/trigger/trigger.py new file mode 100644 index 0000000000..13f4a7e234 --- /dev/null +++ b/api/controllers/trigger/trigger.py @@ -0,0 +1,23 @@ +import logging + +from flask import jsonify, request +from werkzeug.exceptions import NotFound + +from controllers.trigger import bp +from services.trigger_service import TriggerService + +logger = logging.getLogger(__name__) + + +@bp.route("/trigger/webhook/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]) +def trigger_webhook(endpoint_id: str): + """ + Handle webhook trigger calls. + """ + try: + return TriggerService.process_webhook(endpoint_id, request) + except ValueError as e: + raise NotFound(str(e)) + except Exception as e: + logger.exception("Webhook processing failed for {endpoint_id}") + return jsonify({"error": "Internal server error", "message": str(e)}), 500 diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 48ec3be5c8..f3c45b09ce 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -68,6 +68,19 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" +class TriggerProviderCredentialCache(ProviderCredentialsCache): + """Cache for trigger provider credentials""" + + def __init__(self, tenant_id: str, provider: str, credential_id: str): + super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider = kwargs["provider"] + credential_id = kwargs["credential_id"] + return f"trigger_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" + + class NoOpProviderCredentialCache: """No-op provider credential cache""" diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index a07b58d9ea..966c1f0f23 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -184,6 +184,10 @@ class ToolProviderID(GenericProviderID): self.plugin_name = f"{self.provider_name}_tool" +class TriggerProviderID(GenericProviderID): + pass + + class PluginDependency(BaseModel): class Type(enum.StrEnum): Github = PluginInstallationSource.Github.value diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 16ab661092..f2c709b7af 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -1,3 +1,4 @@ +import enum from collections.abc import Mapping, Sequence from datetime import datetime from enum import StrEnum @@ -13,6 +14,7 @@ from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin +from core.trigger.entities import TriggerProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) @@ -196,3 +198,43 @@ class PluginListResponse(BaseModel): class PluginDynamicSelectOptionsResponse(BaseModel): options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.") + + +class PluginTriggerProviderEntity(BaseModel): + provider: str + plugin_unique_identifier: str + plugin_id: str + declaration: TriggerProviderEntity + + +class CredentialType(enum.StrEnum): + API_KEY = "api-key" + OAUTH2 = "oauth2" + + def get_name(self): + if self == CredentialType.API_KEY: + return "API KEY" + elif self == CredentialType.OAUTH2: + return "AUTH" + else: + return self.value.replace("-", " ").upper() + + def is_editable(self): + return self == CredentialType.API_KEY + + def is_validate_allowed(self): + return self == CredentialType.API_KEY + + @classmethod + def values(cls): + return [item.value for item in cls] + + @classmethod + def of(cls, credential_type: str) -> "CredentialType": + type_name = credential_type.lower() + if type_name == "api-key": + return cls.API_KEY + elif type_name == "oauth2": + return cls.OAUTH2 + else: + raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 4c1558efcc..75d0e15ed1 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -4,10 +4,10 @@ from typing import Any, Optional from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID -from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity +from core.plugin.entities.plugin_daemon import CredentialType, PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient from core.plugin.utils.chunk_merger import merge_blob_chunks -from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter class PluginToolManager(BasePluginClient): diff --git a/api/core/plugin/impl/trigger.py b/api/core/plugin/impl/trigger.py new file mode 100644 index 0000000000..d0826daf56 --- /dev/null +++ b/api/core/plugin/impl/trigger.py @@ -0,0 +1,68 @@ +from typing import Any + +from core.plugin.entities.plugin import ToolProviderID +from core.plugin.entities.plugin_daemon import PluginToolProviderEntity, PluginTriggerProviderEntity +from core.plugin.impl.base import BasePluginClient + + +class PluginTriggerManager(BasePluginClient): + def fetch_trigger_providers(self, tenant_id: str) -> list[PluginTriggerProviderEntity]: + """ + Fetch tool providers for the given tenant. + """ + + def transformer(json_response: dict[str, Any]) -> dict: + for provider in json_response.get("data", []): + declaration = provider.get("declaration", {}) or {} + provider_name = declaration.get("identity", {}).get("name") + for tool in declaration.get("tools", []): + tool["identity"]["provider"] = provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/tools", + list[PluginToolProviderEntity], + params={"page": 1, "page_size": 256}, + transformer=transformer, + ) + + for provider in response: + provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in provider.declaration.tools: + tool.identity.provider = provider.declaration.identity.name + + return response + + def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity: + """ + Fetch tool provider for the given tenant and plugin. + """ + tool_provider_id = ToolProviderID(provider) + + def transformer(json_response: dict[str, Any]) -> dict: + data = json_response.get("data") + if data: + for tool in data.get("declaration", {}).get("tools", []): + tool["identity"]["provider"] = tool_provider_id.provider_name + + return json_response + + response = self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/tool", + PluginToolProviderEntity, + params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}, + transformer=transformer, + ) + + response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}" + + # override the provider name for each tool to plugin_id/provider_name + for tool in response.declaration.tools: + tool.identity.provider = response.declaration.identity.name + + return response diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index ddec7b1329..bf08c0398d 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,8 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom +from core.plugin.entities.plugin_daemon import CredentialType +from core.tools.entities.tool_entities import ToolInvokeFrom class ToolRuntime(BaseModel): diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index a70ded9efd..40b36582f0 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -4,11 +4,11 @@ from typing import Any from core.entities.provider_entities import ProviderConfig from core.helper.module_import_helper import load_single_subclass_from_source +from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ( - CredentialType, OAuthSchema, ToolEntity, ToolProviderEntity, diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 48015c04ee..6492f1e455 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -4,9 +4,10 @@ from typing import Any, Literal, Optional from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import CredentialType, ToolProviderType +from core.tools.entities.tool_entities import ToolProviderType class ToolApiEntity(BaseModel): diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index df599a09a3..11f9956f19 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -476,36 +476,3 @@ class ToolSelector(BaseModel): def to_plugin_parameter(self) -> dict[str, Any]: return self.model_dump() - - -class CredentialType(enum.StrEnum): - API_KEY = "api-key" - OAUTH2 = "oauth2" - - def get_name(self): - if self == CredentialType.API_KEY: - return "API KEY" - elif self == CredentialType.OAUTH2: - return "AUTH" - else: - return self.value.replace("-", " ").upper() - - def is_editable(self): - return self == CredentialType.API_KEY - - def is_validate_allowed(self): - return self == CredentialType.API_KEY - - @classmethod - def values(cls): - return [item.value for item in cls] - - @classmethod - def of(cls, credential_type: str) -> "CredentialType": - type_name = credential_type.lower() - if type_name == "api-key": - return cls.API_KEY - elif type_name == "oauth2": - return cls.OAUTH2 - else: - raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index b338a779ac..3924b0133b 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -37,6 +37,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import Tool from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -47,7 +48,6 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, - CredentialType, ToolInvokeFrom, ToolParameter, ToolProviderType, diff --git a/api/core/trigger/__init__.py b/api/core/trigger/__init__.py new file mode 100644 index 0000000000..1e5b8bb445 --- /dev/null +++ b/api/core/trigger/__init__.py @@ -0,0 +1 @@ +# Core trigger module initialization diff --git a/api/core/trigger/entities.py b/api/core/trigger/entities.py new file mode 100644 index 0000000000..24e18e4e9b --- /dev/null +++ b/api/core/trigger/entities.py @@ -0,0 +1,244 @@ +from collections.abc import Mapping +from enum import StrEnum +from typing import Any, Optional, Union + +from pydantic import BaseModel, Field + +from core.tools.entities.common_entities import I18nObject + + +class TriggerParameterOption(BaseModel): + """ + The option of the trigger parameter + """ + + value: str = Field(..., description="The value of the option") + label: I18nObject = Field(..., description="The label of the option") + + +class TriggerParameterType(StrEnum): + """The type of the parameter""" + + STRING = "string" + NUMBER = "number" + BOOLEAN = "boolean" + SELECT = "select" + FILE = "file" + FILES = "files" + MODEL_SELECTOR = "model-selector" + APP_SELECTOR = "app-selector" + OBJECT = "object" + ARRAY = "array" + DYNAMIC_SELECT = "dynamic-select" + + +class ParameterAutoGenerate(BaseModel): + """Auto generation configuration for parameters""" + + enabled: bool = Field(default=False, description="Whether auto generation is enabled") + template: Optional[str] = Field(default=None, description="Template for auto generation") + + +class ParameterTemplate(BaseModel): + """Template configuration for parameters""" + + value: str = Field(..., description="Template value") + type: str = Field(default="jinja2", description="Template type") + + +class TriggerParameter(BaseModel): + """ + The parameter of the trigger + """ + + name: str = Field(..., description="The name of the parameter") + label: I18nObject = Field(..., description="The label presented to the user") + type: TriggerParameterType = Field(..., description="The type of the parameter") + auto_generate: Optional[ParameterAutoGenerate] = Field( + default=None, description="The auto generate of the parameter" + ) + template: Optional[ParameterTemplate] = Field(default=None, description="The template of the parameter") + scope: Optional[str] = None + required: Optional[bool] = False + default: Union[int, float, str, None] = None + min: Union[float, int, None] = None + max: Union[float, int, None] = None + precision: Optional[int] = None + options: Optional[list[TriggerParameterOption]] = None + description: Optional[I18nObject] = None + + +class TriggerProviderIdentity(BaseModel): + """ + The identity of the trigger provider + """ + + author: str = Field(..., description="The author of the trigger provider") + name: str = Field(..., description="The name of the trigger provider") + label: I18nObject = Field(..., description="The label of the trigger provider") + description: I18nObject = Field(..., description="The description of the trigger provider") + icon: Optional[str] = Field(default=None, description="The icon of the trigger provider") + tags: list[str] = Field(default_factory=list, description="The tags of the trigger provider") + + +class TriggerIdentity(BaseModel): + """ + The identity of the trigger + """ + + author: str = Field(..., description="The author of the trigger") + name: str = Field(..., description="The name of the trigger") + label: I18nObject = Field(..., description="The label of the trigger") + + +class TriggerDescription(BaseModel): + """ + The description of the trigger + """ + + human: I18nObject = Field(..., description="Human readable description") + llm: I18nObject = Field(..., description="LLM readable description") + + +class TriggerConfigurationExtraPython(BaseModel): + """Python configuration for trigger""" + + source: str = Field(..., description="The source file path for the trigger implementation") + + +class TriggerConfigurationExtra(BaseModel): + """ + The extra configuration for trigger + """ + + +class TriggerEntity(BaseModel): + """ + The configuration of a trigger + """ + + python: TriggerConfigurationExtraPython + identity: TriggerIdentity = Field(..., description="The identity of the trigger") + parameters: list[TriggerParameter] = Field(default=[], description="The parameters of the trigger") + description: TriggerDescription = Field(..., description="The description of the trigger") + extra: TriggerConfigurationExtra = Field(..., description="The extra configuration of the trigger") + output_schema: Optional[Mapping[str, Any]] = Field( + default=None, description="The output schema that this trigger produces" + ) + + +class TriggerProviderConfigurationExtraPython(BaseModel): + """Python configuration for trigger provider""" + + source: str = Field(..., description="The source file path for the trigger provider implementation") + + +class TriggerProviderConfigurationExtra(BaseModel): + """ + The extra configuration for trigger provider + """ + + python: TriggerProviderConfigurationExtraPython + + +class OAuthSchema(BaseModel): + """OAuth configuration schema""" + + authorization_url: str = Field(..., description="OAuth authorization URL") + token_url: str = Field(..., description="OAuth token URL") + client_id: str = Field(..., description="OAuth client ID") + client_secret: str = Field(..., description="OAuth client secret") + redirect_uri: Optional[str] = Field(default=None, description="OAuth redirect URI") + scope: Optional[str] = Field(default=None, description="OAuth scope") + + +class ProviderConfig(BaseModel): + """Provider configuration item""" + + name: str = Field(..., description="Configuration field name") + type: str = Field(..., description="Configuration field type") + required: bool = Field(default=False, description="Whether this field is required") + default: Any = Field(default=None, description="Default value") + label: Optional[I18nObject] = Field(default=None, description="Field label") + description: Optional[I18nObject] = Field(default=None, description="Field description") + options: Optional[list[dict[str, Any]]] = Field(default=None, description="Options for select type") + + +class TriggerProviderEntity(BaseModel): + """ + The configuration of a trigger provider + """ + + identity: TriggerProviderIdentity = Field(..., description="The identity of the trigger provider") + credentials_schema: list[ProviderConfig] = Field( + default_factory=list, + description="The credentials schema of the trigger provider", + ) + oauth_schema: Optional[OAuthSchema] = Field( + default=None, + description="The OAuth schema of the trigger provider if OAuth is supported", + ) + subscription_schema: list[ProviderConfig] = Field( + default_factory=list, + description="The subscription schema for trigger(webhook, polling, etc.) subscription parameters", + ) + triggers: list[TriggerEntity] = Field(default=[], description="The triggers of the trigger provider") + extra: TriggerProviderConfigurationExtra = Field(..., description="The extra configuration of the trigger provider") + + +class Subscription(BaseModel): + """ + Result of a successful trigger subscription operation. + + Contains all information needed to manage the subscription lifecycle. + """ + + expire_at: int = Field( + ..., description="The timestamp when the subscription will expire, this for refresh the subscription" + ) + + metadata: dict[str, Any] = Field( + ..., description="Metadata about the subscription in the external service, defined in subscription_schema" + ) + + +class Unsubscription(BaseModel): + """ + Result of a trigger unsubscription operation. + + Provides detailed information about the unsubscription attempt, + including success status and error details if failed. + """ + + success: bool = Field(..., description="Whether the unsubscription was successful") + + message: Optional[str] = Field( + None, + description="Human-readable message about the operation result. " + "Success message for successful operations, " + "detailed error information for failures.", + ) + + +# Export all entities +__all__ = [ + "OAuthSchema", + "ParameterAutoGenerate", + "ParameterTemplate", + "ProviderConfig", + "Subscription", + "TriggerConfigurationExtra", + "TriggerConfigurationExtraPython", + "TriggerDescription", + "TriggerEntity", + "TriggerEntity", + "TriggerIdentity", + "TriggerParameter", + "TriggerParameterOption", + "TriggerParameterType", + "TriggerProviderConfigurationExtra", + "TriggerProviderConfigurationExtraPython", + "TriggerProviderEntity", + "TriggerProviderIdentity", + "Unsubscription", +] diff --git a/api/core/trigger/provider.py b/api/core/trigger/provider.py new file mode 100644 index 0000000000..92b8ccbd91 --- /dev/null +++ b/api/core/trigger/provider.py @@ -0,0 +1,199 @@ +""" +Trigger Provider Controller for managing trigger providers +""" + +import logging +import time +from typing import Optional + +from core.plugin.entities.plugin_daemon import CredentialType +from core.trigger.entities import ( + ProviderConfig, + Subscription, + TriggerEntity, + TriggerProviderEntity, + TriggerProviderIdentity, + Unsubscription, +) + +logger = logging.getLogger(__name__) + + +class TriggerProviderController: + """ + Controller for plugin trigger providers + """ + + def __init__( + self, + entity: TriggerProviderEntity, + plugin_id: str, + plugin_unique_identifier: str, + tenant_id: str, + ): + """ + Initialize plugin trigger provider controller + + :param entity: Trigger provider entity + :param plugin_id: Plugin ID + :param plugin_unique_identifier: Plugin unique identifier + :param tenant_id: Tenant ID + """ + self.entity = entity + self.tenant_id = tenant_id + self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier + + @property + def identity(self) -> TriggerProviderIdentity: + """Get provider identity""" + return self.entity.identity + + def get_triggers(self) -> list[TriggerEntity]: + """ + Get all triggers for this provider + + :return: List of trigger entities + """ + return self.entity.triggers + + def get_trigger(self, trigger_name: str) -> Optional[TriggerEntity]: + """ + Get a specific trigger by name + + :param trigger_name: Trigger name + :return: Trigger entity or None + """ + for trigger in self.entity.triggers: + if trigger.identity.name == trigger_name: + return trigger + return None + + def get_credentials_schema(self) -> list[ProviderConfig]: + """ + Get credentials schema for this provider + + :return: List of provider config schemas + """ + return self.entity.credentials_schema + + def get_subscription_schema(self) -> list[ProviderConfig]: + """ + Get subscription schema for this provider + + :return: List of subscription config schemas + """ + return self.entity.subscription_schema + + def validate_credentials(self, credentials: dict) -> None: + """ + Validate credentials against schema + + :param credentials: Credentials to validate + :raises ValueError: If credentials are invalid + """ + for config in self.entity.credentials_schema: + if config.required and config.name not in credentials: + raise ValueError(f"Missing required credential field: {config.name}") + + def get_supported_credential_types(self) -> list[CredentialType]: + """ + Get supported credential types for this provider. + + :return: List of supported credential types + """ + types = [] + if self.entity.oauth_schema: + types.append(CredentialType.OAUTH2) + if self.entity.credentials_schema: + types.append(CredentialType.API_KEY) + return types + + def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: + """ + Get credentials schema by credential type + + :param credential_type: The type of credential (oauth or api_key) + :return: List of provider config schemas + """ + if credential_type == CredentialType.OAUTH2.value: + return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] + if credential_type == CredentialType.API_KEY.value: + return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + raise ValueError(f"Invalid credential type: {credential_type}") + + def get_oauth_client_schema(self) -> list[ProviderConfig]: + """ + Get OAuth client schema for this provider + + :return: List of OAuth client config schemas + """ + return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] + + @property + def need_credentials(self) -> bool: + """Check if this provider needs credentials""" + return len(self.get_supported_credential_types()) > 0 + + def execute_trigger(self, trigger_name: str, parameters: dict, credentials: dict) -> dict: + """ + Execute a trigger through plugin runtime + + :param trigger_name: Trigger name + :param parameters: Trigger parameters + :param credentials: Provider credentials + :return: Execution result + """ + logger.info("Executing trigger %s for plugin %s", trigger_name, self.plugin_id) + return { + "success": True, + "trigger": trigger_name, + "plugin": self.plugin_id, + "result": "Trigger executed successfully", + } + + def subscribe_trigger(self, trigger_name: str, subscription_params: dict, credentials: dict) -> Subscription: + """ + Subscribe to a trigger through plugin runtime + + :param trigger_name: Trigger name + :param subscription_params: Subscription parameters + :param credentials: Provider credentials + :return: Subscription result + """ + logger.info("Subscribing to trigger %s for plugin %s", trigger_name, self.plugin_id) + return Subscription( + expire_at=int(time.time()) + 86400, # 24 hours from now + metadata={ + "subscription_id": f"{self.plugin_id}_{trigger_name}_{time.time()}", + "webhook_url": f"/triggers/webhook/{self.plugin_id}/{trigger_name}", + **subscription_params, + }, + ) + + def unsubscribe_trigger(self, trigger_name: str, subscription_metadata: dict, credentials: dict) -> Unsubscription: + """ + Unsubscribe from a trigger through plugin runtime + + :param trigger_name: Trigger name + :param subscription_metadata: Subscription metadata + :param credentials: Provider credentials + :return: Unsubscription result + """ + logger.info("Unsubscribing from trigger %s for plugin %s", trigger_name, self.plugin_id) + return Unsubscription(success=True, message=f"Successfully unsubscribed from trigger {trigger_name}") + + def handle_webhook(self, webhook_path: str, request_data: dict, credentials: dict) -> dict: + """ + Handle incoming webhook through plugin runtime + + :param webhook_path: Webhook path + :param request_data: Request data + :param credentials: Provider credentials + :return: Webhook handling result + """ + logger.info("Handling webhook for path %s for plugin %s", webhook_path, self.plugin_id) + return {"success": True, "path": webhook_path, "plugin": self.plugin_id, "data_received": request_data} + + +__all__ = ["TriggerProviderController"] diff --git a/api/core/trigger/trigger_manager.py b/api/core/trigger/trigger_manager.py new file mode 100644 index 0000000000..91e5a43519 --- /dev/null +++ b/api/core/trigger/trigger_manager.py @@ -0,0 +1,360 @@ +""" +Trigger Manager for loading and managing trigger providers and triggers +""" + +import logging +from typing import Optional + +from core.trigger.entities import ( + ProviderConfig, + TriggerEntity, +) +from core.trigger.plugin_trigger import PluginTriggerController +from core.trigger.provider import PluginTriggerProviderController + +logger = logging.getLogger(__name__) + + +class TriggerManager: + """ + Manager for trigger providers and triggers + """ + + @classmethod + def list_plugin_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]: + """ + List all plugin trigger providers for a tenant + + :param tenant_id: Tenant ID + :return: List of trigger provider controllers + """ + manager = PluginTriggerController() + provider_entities = manager.fetch_trigger_providers(tenant_id) + + controllers = [] + for provider in provider_entities: + try: + controller = PluginTriggerProviderController( + entity=provider.declaration, + plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + tenant_id=tenant_id, + ) + controllers.append(controller) + except Exception as e: + logger.exception("Failed to load trigger provider {provider.plugin_id}") + continue + + return controllers + + @classmethod + def get_plugin_trigger_provider( + cls, tenant_id: str, plugin_id: str, provider_name: str + ) -> Optional[PluginTriggerProviderController]: + """ + Get a specific plugin trigger provider + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :return: Trigger provider controller or None + """ + manager = PluginTriggerManager() + provider = manager.fetch_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + return None + + try: + return PluginTriggerProviderController( + entity=provider.declaration, + plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + tenant_id=tenant_id, + ) + except Exception as e: + logger.exception("Failed to load trigger provider") + return None + + @classmethod + def list_all_trigger_providers(cls, tenant_id: str) -> list[PluginTriggerProviderController]: + """ + List all trigger providers (plugin and builtin) + + :param tenant_id: Tenant ID + :return: List of all trigger provider controllers + """ + providers = [] + + # Get plugin providers + plugin_providers = cls.list_plugin_trigger_providers(tenant_id) + providers.extend(plugin_providers) + + # TODO: Add builtin providers when implemented + # builtin_providers = cls.list_builtin_trigger_providers(tenant_id) + # providers.extend(builtin_providers) + + return providers + + @classmethod + def list_triggers_by_provider(cls, tenant_id: str, plugin_id: str, provider_name: str) -> list[TriggerEntity]: + """ + List all triggers for a specific provider + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :return: List of trigger entities + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + return [] + + return provider.get_triggers() + + @classmethod + def get_trigger( + cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str + ) -> Optional[TriggerEntity]: + """ + Get a specific trigger + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :param trigger_name: Trigger name + :return: Trigger entity or None + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + return None + + return provider.get_trigger(trigger_name) + + @classmethod + def validate_trigger_credentials( + cls, tenant_id: str, plugin_id: str, provider_name: str, credentials: dict + ) -> tuple[bool, str]: + """ + Validate trigger provider credentials + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :param credentials: Credentials to validate + :return: Tuple of (is_valid, error_message) + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + return False, "Provider not found" + + try: + provider.validate_credentials(credentials) + return True, "" + except Exception as e: + return False, str(e) + + @classmethod + def execute_trigger( + cls, tenant_id: str, plugin_id: str, provider_name: str, trigger_name: str, parameters: dict, credentials: dict + ) -> dict: + """ + Execute a trigger + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :param trigger_name: Trigger name + :param parameters: Trigger parameters + :param credentials: Provider credentials + :return: Trigger execution result + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + raise ValueError(f"Provider {plugin_id}/{provider_name} not found") + + trigger = provider.get_trigger(trigger_name) + if not trigger: + raise ValueError(f"Trigger {trigger_name} not found in provider {provider_name}") + + return provider.execute_trigger(trigger_name, parameters, credentials) + + @classmethod + def subscribe_trigger( + cls, + tenant_id: str, + plugin_id: str, + provider_name: str, + trigger_name: str, + subscription_params: dict, + credentials: dict, + ) -> dict: + """ + Subscribe to a trigger (e.g., register webhook) + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :param trigger_name: Trigger name + :param subscription_params: Subscription parameters + :param credentials: Provider credentials + :return: Subscription result + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + raise ValueError(f"Provider {plugin_id}/{provider_name} not found") + + return provider.subscribe_trigger(trigger_name, subscription_params, credentials) + + @classmethod + def unsubscribe_trigger( + cls, + tenant_id: str, + plugin_id: str, + provider_name: str, + trigger_name: str, + subscription_metadata: dict, + credentials: dict, + ) -> dict: + """ + Unsubscribe from a trigger + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :param trigger_name: Trigger name + :param subscription_metadata: Subscription metadata from subscribe operation + :param credentials: Provider credentials + :return: Unsubscription result + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + raise ValueError(f"Provider {plugin_id}/{provider_name} not found") + + return provider.unsubscribe_trigger(trigger_name, subscription_metadata, credentials) + + @classmethod + def handle_webhook( + cls, + tenant_id: str, + plugin_id: str, + provider_name: str, + webhook_path: str, + request_data: dict, + credentials: dict, + ) -> dict: + """ + Handle incoming webhook for a trigger + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :param webhook_path: Webhook path + :param request_data: Webhook request data + :param credentials: Provider credentials + :return: Webhook handling result + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + raise ValueError(f"Provider {plugin_id}/{provider_name} not found") + + return provider.handle_webhook(webhook_path, request_data, credentials) + + @classmethod + def get_provider_credentials_schema( + cls, tenant_id: str, plugin_id: str, provider_name: str + ) -> list[ProviderConfig]: + """ + Get provider credentials schema + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :return: List of provider config schemas + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + return [] + + return provider.get_credentials_schema() + + @classmethod + def get_provider_subscription_schema( + cls, tenant_id: str, plugin_id: str, provider_name: str + ) -> list[ProviderConfig]: + """ + Get provider subscription schema + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :return: List of subscription config schemas + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + return [] + + return provider.get_subscription_schema() + + @classmethod + def get_provider_info(cls, tenant_id: str, plugin_id: str, provider_name: str) -> Optional[dict]: + """ + Get provider information + + :param tenant_id: Tenant ID + :param plugin_id: Plugin ID + :param provider_name: Provider name + :return: Provider info dict or None + """ + provider = cls.get_plugin_trigger_provider(tenant_id, plugin_id, provider_name) + + if not provider: + return None + + return { + "plugin_id": plugin_id, + "provider_name": provider_name, + "identity": provider.entity.identity.model_dump() if provider.entity.identity else {}, + "credentials_schema": [c.model_dump() for c in provider.entity.credentials_schema], + "subscription_schema": [s.model_dump() for s in provider.entity.subscription_schema], + "oauth_enabled": provider.entity.oauth_schema is not None, + "trigger_count": len(provider.entity.triggers), + "triggers": [t.identity.model_dump() for t in provider.entity.triggers], + } + + @classmethod + def list_providers_for_workflow(cls, tenant_id: str) -> list[dict]: + """ + List trigger providers suitable for workflow usage + + :param tenant_id: Tenant ID + :return: List of provider info dicts + """ + providers = cls.list_all_trigger_providers(tenant_id) + + result = [] + for provider in providers: + info = { + "plugin_id": provider.plugin_id, + "provider_name": provider.entity.identity.name, + "label": provider.entity.identity.label.model_dump(), + "description": provider.entity.identity.description.model_dump(), + "icon": provider.entity.identity.icon, + "trigger_count": len(provider.entity.triggers), + } + result.append(info) + + return result + + +# Export +__all__ = ["TriggerManager"] diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 8904ff7a92..a71084e8f3 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -20,6 +20,7 @@ def init_app(app: DifyApp): reset_encrypt_key_pair, reset_password, setup_system_tool_oauth_client, + setup_system_trigger_oauth_client, upgrade_db, vdb_migrate, ) @@ -43,6 +44,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, remove_orphaned_files_on_storage, setup_system_tool_oauth_client, + setup_system_trigger_oauth_client, cleanup_orphaned_draft_variables, ] for cmd in cmds_to_register: diff --git a/api/models/trigger.py b/api/models/trigger.py new file mode 100644 index 0000000000..09e066c93e --- /dev/null +++ b/api/models/trigger.py @@ -0,0 +1,97 @@ +import json +from datetime import UTC, datetime +from typing import cast + +import sqlalchemy as sa +from sqlalchemy import DateTime, Index, Integer, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column + +from core.plugin.entities.plugin_daemon import CredentialType +from models.base import Base +from models.types import StringUUID + + +class TriggerProvider(Base): + """ + Trigger provider model for managing credentials + Supports multiple credential instances per provider + """ + + __tablename__ = "trigger_providers" + __table_args__ = (Index("idx_trigger_providers_tenant_provider", "tenant_id", "provider_id"),) + + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_id: Mapped[str] = mapped_column( + String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)" + ) + credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key") + encrypted_credentials: Mapped[str] = mapped_column(Text, nullable=False, comment="Encrypted credentials JSON") + name: Mapped[str] = mapped_column(String(255), nullable=False, comment="Credential instance name") + expires_at: Mapped[int] = mapped_column( + Integer, default=-1, comment="OAuth token expiration timestamp, -1 for never" + ) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.now(), onupdate=func.now() + ) + + @property + def credentials(self) -> dict: + """Get credentials as dict (still encrypted)""" + try: + return json.loads(self.encrypted_credentials) if self.encrypted_credentials else {} + except (json.JSONDecodeError, TypeError): + return {} + + @property + def credentials_str(self) -> str: + """Get credentials as string""" + return self.encrypted_credentials or "{}" + + def is_oauth_expired(self) -> bool: + """Check if OAuth token is expired""" + if self.credential_type != CredentialType.OAUTH2.value: + return False + if self.expires_at == -1: + return False + # Check if token expires in next 60 seconds + return (self.expires_at - 60) < int(datetime.now(UTC).timestamp()) + + +# system level trigger oauth client params +class TriggerOAuthSystemClient(Base): + __tablename__ = "trigger_oauth_system_clients" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="trigger_oauth_system_client_pkey"), + sa.UniqueConstraint("plugin_id", "provider", name="trigger_oauth_system_client_plugin_id_provider_idx"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) + # oauth params of the trigger provider + encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) + + +# tenant level trigger oauth client params (client_id, client_secret, etc.) +class TriggerOAuthTenantClient(Base): + __tablename__ = "trigger_oauth_tenant_clients" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="trigger_oauth_tenant_client_pkey"), + sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_trigger_oauth_tenant_client"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + # tenant id + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + # oauth params of the trigger provider + encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) + + @property + def oauth_params(self) -> dict: + return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 71bc50017f..b4bc4df656 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -13,6 +13,7 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID +from core.plugin.entities.plugin_daemon import CredentialType from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ( @@ -21,7 +22,6 @@ from core.tools.entities.api_entities import ( ToolProviderCredentialApiEntity, ToolProviderCredentialInfoApiEntity, ) -from core.tools.entities.tool_entities import CredentialType from core.tools.errors import ToolProviderNotFoundError from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager @@ -39,7 +39,6 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 - __DEFAULT_EXPIRES_AT__ = 2147483647 @staticmethod def delete_custom_oauth_client_params(tenant_id: str, provider: str): @@ -278,9 +277,7 @@ class BuiltinToolManageService: encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), credential_type=api_type.value, name=name, - expires_at=expires_at - if expires_at is not None - else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__, + expires_at=expires_at if expires_at is not None else -1, ) session.add(db_provider) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 52fbc0979c..5a05907f37 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -7,6 +7,7 @@ from yarl import URL from configs import dify_config from core.helper.provider_cache import ToolProviderCredentialsCache from core.mcp.types import Tool as MCPTool +from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -16,7 +17,6 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, - CredentialType, ToolParameter, ToolProviderType, ) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py new file mode 100644 index 0000000000..128aaceabf --- /dev/null +++ b/api/services/trigger/trigger_provider_service.py @@ -0,0 +1,588 @@ +import json +import logging +import re +from collections.abc import Mapping +from typing import Any, Optional + +from sqlalchemy import desc +from sqlalchemy.orm import Session + +from configs import dify_config +from constants import HIDDEN_VALUE, UNKNOWN_VALUE +from core.helper.provider_cache import NoOpProviderCredentialCache +from core.plugin.entities.plugin import TriggerProviderID +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.oauth import OAuthHandler +from core.plugin.service import PluginService +from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter +from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.trigger.trigger_manager import TriggerManager +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.trigger import TriggerOAuthSystemClient, TriggerOAuthTenantClient, TriggerProvider +from services.plugin.oauth_service import OAuthProxyService + +logger = logging.getLogger(__name__) + + +class TriggerProviderService: + """Service for managing trigger providers and credentials""" + + __MAX_TRIGGER_PROVIDER_COUNT__ = 100 + + @classmethod + def list_trigger_providers(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[TriggerProvider]: + """List all trigger providers for the current tenant""" + # TODO fetch trigger plugin controller + + # TODO fetch all trigger plugin credentials + with Session(db.engine, autoflush=False) as session: + return session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).all() + + @classmethod + def add_trigger_provider( + cls, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + credential_type: CredentialType, + credentials: dict, + name: Optional[str] = None, + expires_at: int = -1, + ) -> dict: + """ + Add a new trigger provider with credentials. + Supports multiple credential instances per provider. + + :param tenant_id: Tenant ID + :param provider_id: Provider identifier (e.g., "plugin_id/provider_name") + :param credential_type: Type of credential (oauth or api_key) + :param credentials: Credential data to encrypt and store + :param name: Optional name for this credential instance + :param expires_at: OAuth token expiration timestamp + :return: Success response + """ + try: + with Session(db.engine) as session: + # Use distributed lock to prevent race conditions + lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}" + with redis_client.lock(lock_key, timeout=20): + # Check provider count limit + provider_count = ( + session.query(TriggerProvider).filter_by(tenant_id=tenant_id, provider_id=provider_id).count() + ) + + if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__: + raise ValueError( + f"Maximum number of providers ({cls.__MAX_TRIGGER_PROVIDER_COUNT__}) " + f"reached for {provider_id}" + ) + + # Generate name if not provided + if not name: + name = cls._generate_provider_name( + session=session, + tenant_id=tenant_id, + provider_id=provider_id, + credential_type=credential_type, + ) + else: + # Check if name already exists + existing = ( + session.query(TriggerProvider) + .filter_by(tenant_id=tenant_id, provider_id=provider_id, name=name) + .first() + ) + if existing: + raise ValueError(f"Credential name '{name}' already exists for this provider") + + # Create encrypter for credentials + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[], # We'll define schema later in TriggerProvider classes + cache=NoOpProviderCredentialCache(), + ) + + # Create provider record + db_provider = TriggerProvider( + tenant_id=tenant_id, + user_id=user_id, + provider_id=provider_id, + credential_type=credential_type.value, + encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), + name=name, + expires_at=expires_at, + ) + + session.add(db_provider) + session.commit() + + return {"result": "success", "id": str(db_provider.id)} + + except Exception as e: + logger.exception("Failed to add trigger provider") + raise ValueError(str(e)) + + @classmethod + def update_trigger_provider( + cls, + tenant_id: str, + credential_id: str, + credentials: Optional[dict] = None, + name: Optional[str] = None, + ) -> dict: + """ + Update an existing trigger provider's credentials or name. + + :param tenant_id: Tenant ID + :param credential_id: Credential instance ID + :param credentials: New credentials (optional) + :param name: New name (optional) + :return: Success response + """ + with Session(db.engine) as session: + # Get provider + db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() + + if not db_provider: + raise ValueError(f"Trigger provider credential {credential_id} not found") + + try: + # Update credentials if provided + if credentials: + encrypter, cache = cls._create_provider_encrypter( + tenant_id=tenant_id, + provider=db_provider, + ) + + # Handle hidden values + original_credentials = encrypter.decrypt(db_provider.credentials) + new_credentials = { + key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE) + for key, value in credentials.items() + } + + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials)) + cache.delete() + + # Update name if provided + if name and name != db_provider.name: + # Check if name already exists + existing = ( + session.query(TriggerProvider) + .filter_by(tenant_id=tenant_id, provider_id=db_provider.provider_id, name=name) + .filter(TriggerProvider.id != credential_id) + .first() + ) + if existing: + raise ValueError(f"Credential name '{name}' already exists") + + db_provider.name = name + + session.commit() + return {"result": "success"} + + except Exception as e: + session.rollback() + raise ValueError(str(e)) + + @classmethod + def delete_trigger_provider(cls, tenant_id: str, credential_id: str) -> dict: + """ + Delete a trigger provider credential. + + :param tenant_id: Tenant ID + :param credential_id: Credential instance ID + :return: Success response + """ + with Session(db.engine) as session: + db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() + if not db_provider: + raise ValueError(f"Trigger provider credential {credential_id} not found") + + # Delete provider + session.delete(db_provider) + session.commit() + + # Clear cache + _, cache = cls._create_provider_encrypter(tenant_id, db_provider) + cache.delete() + + return {"result": "success"} + + @classmethod + def refresh_oauth_token( + cls, + tenant_id: str, + credential_id: str, + ) -> dict: + """ + Refresh OAuth token for a trigger provider. + + :param tenant_id: Tenant ID + :param credential_id: Credential instance ID + :return: New token info + """ + with Session(db.engine) as session: + db_provider = session.query(TriggerProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() + + if not db_provider: + raise ValueError(f"Trigger provider credential {credential_id} not found") + + if db_provider.credential_type != CredentialType.OAUTH2.value: + raise ValueError("Only OAuth credentials can be refreshed") + + # Parse provider ID + provider_id = TriggerProviderID(db_provider.provider_id) + + # Create encrypter + encrypter, cache = cls._create_provider_encrypter( + tenant_id=tenant_id, + provider=db_provider, + ) + + # Decrypt current credentials + current_credentials = encrypter.decrypt(db_provider.credentials) + + # Get OAuth client configuration + redirect_uri = ( + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{db_provider.provider_id}/trigger/callback" + ) + system_credentials = cls.get_oauth_client(tenant_id, provider_id) + + # Refresh token + oauth_handler = OAuthHandler() + refreshed_credentials = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=db_provider.user_id, + plugin_id=provider_id.plugin_id, + provider=provider_id.provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=current_credentials, + ) + + # Update credentials + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(dict(refreshed_credentials.credentials))) + db_provider.expires_at = refreshed_credentials.expires_at + session.commit() + + # Clear cache + cache.delete() + + return { + "result": "success", + "expires_at": refreshed_credentials.expires_at, + } + + @classmethod + def get_oauth_client(cls, tenant_id: str, provider_id: TriggerProviderID) -> Optional[Mapping[str, Any]]: + """ + Get OAuth client configuration for a provider. + First tries tenant-level OAuth, then falls back to system OAuth. + + :param tenant_id: Tenant ID + :param provider_id: Provider identifier + :return: OAuth client configuration or None + """ + # Get trigger provider controller to access schema + provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id) + + # Create encrypter for OAuth client params + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + + with Session(db.engine, autoflush=False) as session: + # First check for tenant-specific OAuth client + tenant_client: TriggerOAuthTenantClient | None = ( + session.query(TriggerOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + enabled=True, + ) + .first() + ) + + oauth_params: Mapping[str, Any] | None = None + if tenant_client: + oauth_params = encrypter.decrypt(tenant_client.oauth_params) + return oauth_params + + # Only verified plugins can use system OAuth client + is_verified = PluginService.is_plugin_verified(tenant_id, provider_id.plugin_id) + if not is_verified: + return oauth_params + + # Check for system-level OAuth client + system_client: TriggerOAuthSystemClient | None = ( + session.query(TriggerOAuthSystemClient) + .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) + .first() + ) + + if system_client: + try: + oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + except Exception as e: + raise ValueError(f"Error decrypting system oauth params: {e}") + + return oauth_params + + @classmethod + def save_custom_oauth_client_params( + cls, + tenant_id: str, + provider_id: TriggerProviderID, + client_params: Optional[dict] = None, + enabled: Optional[bool] = None, + ) -> dict: + """ + Save or update custom OAuth client parameters for a trigger provider. + + :param tenant_id: Tenant ID + :param provider_id: Provider identifier + :param client_params: OAuth client parameters (client_id, client_secret, etc.) + :param enabled: Enable/disable the custom OAuth client + :return: Success response + """ + if client_params is None and enabled is None: + return {"result": "success"} + + # Get provider controller to access schema + provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id) + + with Session(db.engine) as session: + # Find existing custom client params + custom_client = ( + session.query(TriggerOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=provider_id.plugin_id, + provider=provider_id.provider_name, + ) + .first() + ) + + # Create new record if doesn't exist + if custom_client is None: + custom_client = TriggerOAuthTenantClient( + tenant_id=tenant_id, + plugin_id=provider_id.plugin_id, + provider=provider_id.provider_name, + ) + session.add(custom_client) + + # Update client params if provided + if client_params is not None: + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + + # Handle hidden values + original_params = encrypter.decrypt(custom_client.oauth_params) + new_params: dict = { + key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE) + for key, value in client_params.items() + } + custom_client.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params)) + + # Update enabled status if provided + if enabled is not None: + custom_client.enabled = enabled + + session.commit() + + return {"result": "success"} + + @classmethod + def get_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict: + """ + Get custom OAuth client parameters for a trigger provider. + + :param tenant_id: Tenant ID + :param provider_id: Provider identifier + :return: Masked OAuth client parameters + """ + with Session(db.engine) as session: + custom_client = ( + session.query(TriggerOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=provider_id.plugin_id, + provider=provider_id.provider_name, + ) + .first() + ) + + if custom_client is None: + return {} + + # Get provider controller to access schema + provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id) + + # Create encrypter to decrypt and mask values + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + + return encrypter.mask_tool_credentials(encrypter.decrypt(custom_client.oauth_params)) + + @classmethod + def delete_custom_oauth_client_params(cls, tenant_id: str, provider_id: TriggerProviderID) -> dict: + """ + Delete custom OAuth client parameters for a trigger provider. + + :param tenant_id: Tenant ID + :param provider_id: Provider identifier + :return: Success response + """ + with Session(db.engine) as session: + session.query(TriggerOAuthTenantClient).filter_by( + tenant_id=tenant_id, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + ).delete() + session.commit() + + return {"result": "success"} + + @classmethod + def is_oauth_custom_client_enabled(cls, tenant_id: str, provider_id: TriggerProviderID) -> bool: + """ + Check if custom OAuth client is enabled for a trigger provider. + + :param tenant_id: Tenant ID + :param provider_id: Provider identifier + :return: True if enabled, False otherwise + """ + with Session(db.engine, autoflush=False) as session: + custom_client = ( + session.query(TriggerOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=provider_id.plugin_id, + provider=provider_id.provider_name, + enabled=True, + ) + .first() + ) + return custom_client is not None + + @classmethod + def create_oauth_proxy_context( + cls, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + ) -> str: + """ + Create OAuth proxy context for authorization flow. + + :param tenant_id: Tenant ID + :param user_id: User ID + :param provider: Provider identifier + :return: Context ID for OAuth flow + """ + return OAuthProxyService.create_proxy_context( + user_id=user_id, + tenant_id=tenant_id, + plugin_id=provider_id.plugin_id, + provider=provider_id.provider_name, + ) + + @classmethod + def _create_provider_encrypter( + cls, tenant_id: str, provider: TriggerProvider + ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]: + """ + Create encrypter and cache for trigger provider credentials + + :param tenant_id: Tenant ID + :param provider: TriggerProvider instance + :return: Tuple of encrypter and cache + """ + from core.helper.provider_cache import TriggerProviderCredentialCache + + # Parse provider ID + provider_id = TriggerProviderID(provider.provider_id) + + # Get trigger provider controller to access schema + provider_controller = TriggerManager.get_trigger_provider(provider_id, tenant_id) + + # Create encrypter with appropriate schema based on credential type + encrypter, cache = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(provider.credential_type) + ], + cache=TriggerProviderCredentialCache( + tenant_id=tenant_id, + provider=provider.provider_id, + credential_id=provider.id, + ), + ) + + return encrypter, cache + + @classmethod + def _generate_provider_name( + cls, + session: Session, + tenant_id: str, + provider_id: TriggerProviderID, + credential_type: CredentialType, + ) -> str: + """ + Generate a unique name for a provider credential instance. + + :param session: Database session + :param tenant_id: Tenant ID + :param provider: Provider identifier + :param credential_type: Credential type + :return: Generated name + """ + try: + db_providers = ( + session.query(TriggerProvider) + .filter_by( + tenant_id=tenant_id, + provider_id=provider_id, + credential_type=credential_type.value, + ) + .order_by(desc(TriggerProvider.created_at)) + .all() + ) + + # Get base name + base_name = credential_type.get_name() + + # Find existing numbered names + pattern = rf"^{re.escape(base_name)}\s+(\d+)$" + numbers = [] + + for db_provider in db_providers: + if db_provider.name: + match = re.match(pattern, db_provider.name.strip()) + if match: + numbers.append(int(match.group(1))) + + # Generate next number + if not numbers: + return f"{base_name} 1" + + max_number = max(numbers) + return f"{base_name} {max_number + 1}" + + except Exception as e: + logger.warning("Error generating provider name") + return f"{credential_type.get_name()} 1" diff --git a/api/services/trigger_service.py b/api/services/trigger_service.py new file mode 100644 index 0000000000..3247fadda3 --- /dev/null +++ b/api/services/trigger_service.py @@ -0,0 +1,23 @@ +import logging +from typing import Any + +from flask import Request, Response + +logger = logging.getLogger(__name__) + + +class TriggerService: + __MAX_REQUEST_LOG_COUNT__ = 10 + + @classmethod + def process_webhook(cls, webhook_id: str, request: Request) -> Response: + """Extract and process data from incoming webhook request.""" + # TODO redis slidingwindow log, save the recent request log in redis, rollover the log when the window is full + + # TODO find the trigger subscription + + # TODO fetch the trigger controller + + # TODO dispatch by the trigger controller + + # TODO using the dispatch result(events) to invoke the trigger events diff --git a/api/services/workflow/queue_dispatcher.py b/api/services/workflow/queue_dispatcher.py index a3d71bbdd8..782f351d31 100644 --- a/api/services/workflow/queue_dispatcher.py +++ b/api/services/workflow/queue_dispatcher.py @@ -55,9 +55,7 @@ class BaseQueueDispatcher(ABC): True if quota available, False otherwise """ # Check without consuming - remaining = self.rate_limiter.get_remaining_quota( - tenant_id=tenant_id, max_daily_limit=self.get_daily_limit() - ) + remaining = self.rate_limiter.get_remaining_quota(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit()) return remaining > 0 def consume_quota(self, tenant_id: str) -> bool: @@ -70,9 +68,7 @@ class BaseQueueDispatcher(ABC): Returns: True if quota consumed successfully, False if limit reached """ - return self.rate_limiter.check_and_consume( - tenant_id=tenant_id, max_daily_limit=self.get_daily_limit() - ) + return self.rate_limiter.check_and_consume(tenant_id=tenant_id, max_daily_limit=self.get_daily_limit()) class ProfessionalQueueDispatcher(BaseQueueDispatcher): diff --git a/api/services/workflow/rate_limiter.py b/api/services/workflow/rate_limiter.py index 6f83ed8ad4..dff284538a 100644 --- a/api/services/workflow/rate_limiter.py +++ b/api/services/workflow/rate_limiter.py @@ -74,10 +74,10 @@ class TenantDailyRateLimiter: Number of seconds until UTC midnight """ utc_now = datetime.utcnow() - + # Get next midnight in UTC next_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min) - + return int((next_midnight - utc_now).total_seconds()) def check_and_consume(self, tenant_id: str, max_daily_limit: int) -> bool: @@ -174,9 +174,9 @@ class TenantDailyRateLimiter: """ tz = pytz.timezone(timezone_str) utc_now = datetime.utcnow() - + # Get next midnight in UTC, then convert to tenant's timezone next_utc_midnight = datetime.combine(utc_now.date() + timedelta(days=1), time.min) next_utc_midnight = pytz.UTC.localize(next_utc_midnight) - + return next_utc_midnight.astimezone(tz)