From 72f9e773689f573de111072ebac6dfa94024ed09 Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 28 Aug 2025 15:53:48 +0800 Subject: [PATCH] refactor(trigger): clean up and optimize trigger-related code - Remove unused classes and imports in encryption utilities - Simplify method signatures for better readability - Enhance code quality by adding newlines for clarity - Update tests to reflect changes in import paths Co-authored-by: Claude --- api/core/helper/provider_cache.py | 2 + api/core/helper/provider_encryption.py | 125 ++++++++++++++++ api/core/tools/utils/encryption.py | 138 ++---------------- api/core/trigger/entities/entities.py | 3 + api/core/trigger/trigger_manager.py | 13 +- api/core/trigger/utils/encryption.py | 46 +++--- api/models/trigger.py | 1 + .../trigger/trigger_provider_service.py | 4 +- api/services/trigger_service.py | 2 +- .../core/tools/utils/test_encryption.py | 10 +- 10 files changed, 184 insertions(+), 160 deletions(-) create mode 100644 api/core/helper/provider_encryption.py diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index f641d399ab..ea5f4f0e4b 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -80,6 +80,7 @@ class TriggerProviderCredentialCache(ProviderCredentialsCache): credential_id = kwargs["credential_id"] return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}" + class TriggerProviderOAuthClientCache(ProviderCredentialsCache): """Cache for trigger provider OAuth client""" @@ -91,6 +92,7 @@ class TriggerProviderOAuthClientCache(ProviderCredentialsCache): provider_id = kwargs["provider_id"] return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}" + class NoOpProviderCredentialCache: """No-op provider credential cache""" diff --git a/api/core/helper/provider_encryption.py b/api/core/helper/provider_encryption.py new file mode 100644 index 0000000000..7f301833e9 --- /dev/null +++ b/api/core/helper/provider_encryption.py @@ -0,0 +1,125 @@ +import contextlib +from copy import deepcopy +from typing import Any, Optional, Protocol + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter + + +class ProviderConfigCache(Protocol): + """ + Interface for provider configuration cache operations + """ + + def get(self) -> Optional[dict]: + """Get cached provider configuration""" + ... + + def set(self, config: dict[str, Any]) -> None: + """Cache provider configuration""" + ... + + def delete(self) -> None: + """Delete cached provider configuration""" + ... + + +class ProviderConfigEncrypter: + tenant_id: str + config: list[BasicProviderConfig] + provider_config_cache: ProviderConfigCache + + def __init__( + self, + tenant_id: str, + config: list[BasicProviderConfig], + provider_config_cache: ProviderConfigCache, + ): + self.tenant_id = tenant_id + self.config = config + self.provider_config_cache = provider_config_cache + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, Any]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cached_credentials = self.provider_config_cache.get() + if cached_credentials: + return cached_credentials + + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + with contextlib.suppress(Exception): + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + + self.provider_config_cache.set(data) + return data + + +def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): + return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index 6cd58da9c7..3b6af302db 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -1,131 +1,23 @@ -import contextlib -from copy import deepcopy -from typing import Any, Optional, Protocol +# Import generic components from provider_encryption module +from core.helper.provider_encryption import ( + ProviderConfigCache, + ProviderConfigEncrypter, + create_provider_encrypter, +) -from core.entities.provider_entities import BasicProviderConfig -from core.helper import encrypter +# Re-export for backward compatibility +__all__ = [ + "ProviderConfigCache", + "ProviderConfigEncrypter", + "create_provider_encrypter", + "create_tool_provider_encrypter", +] + +# Tool-specific imports from core.helper.provider_cache import SingletonProviderCredentialsCache from core.tools.__base.tool_provider import ToolProviderController -class ProviderConfigCache(Protocol): - """ - Interface for provider configuration cache operations - """ - - def get(self) -> Optional[dict]: - """Get cached provider configuration""" - ... - - def set(self, config: dict[str, Any]) -> None: - """Cache provider configuration""" - ... - - def delete(self) -> None: - """Delete cached provider configuration""" - ... - - -class ProviderConfigEncrypter: - tenant_id: str - config: list[BasicProviderConfig] - provider_config_cache: ProviderConfigCache - - def __init__( - self, - tenant_id: str, - config: list[BasicProviderConfig], - provider_config_cache: ProviderConfigCache, - ): - self.tenant_id = tenant_id - self.config = config - self.provider_config_cache = provider_config_cache - - def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: - """ - deep copy data - """ - return deepcopy(data) - - def encrypt(self, data: dict[str, str]) -> dict[str, str]: - """ - encrypt tool credentials with tenant id - - return a deep copy of credentials with encrypted values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") - data[field_name] = encrypted - - return data - - def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: - """ - mask tool credentials - - return a deep copy of credentials with masked values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - if len(data[field_name]) > 6: - data[field_name] = ( - data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] - ) - else: - data[field_name] = "*" * len(data[field_name]) - - return data - - def decrypt(self, data: dict[str, str]) -> dict[str, Any]: - """ - decrypt tool credentials with tenant id - - return a deep copy of credentials with decrypted values - """ - cached_credentials = self.provider_config_cache.get() - if cached_credentials: - return cached_credentials - - data = self._deep_copy(data) - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - with contextlib.suppress(Exception): - # if the value is None or empty string, skip decrypt - if not data[field_name]: - continue - - data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - - self.provider_config_cache.set(data) - return data - -def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): - return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache - - def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController): cache = SingletonProviderCredentialsCache( tenant_id=tenant_id, diff --git a/api/core/trigger/entities/entities.py b/api/core/trigger/entities/entities.py index ca8dc07b67..b83487dad0 100644 --- a/api/core/trigger/entities/entities.py +++ b/api/core/trigger/entities/entities.py @@ -70,6 +70,7 @@ class TriggerIdentity(BaseModel): label: I18nObject = Field(..., description="The label of the trigger") provider: str = Field(..., description="The provider of the trigger") + class TriggerDescription(BaseModel): """ The description of the trigger @@ -91,12 +92,14 @@ class TriggerEntity(BaseModel): default=None, description="The output schema that this trigger produces" ) + class OAuthSchema(BaseModel): client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client") credentials_schema: list[ProviderConfig] = Field( default_factory=list, description="The schema of the OAuth credentials" ) + class TriggerProviderEntity(BaseModel): """ The configuration of a trigger provider diff --git a/api/core/trigger/trigger_manager.py b/api/core/trigger/trigger_manager.py index 4863ea538a..58ae472d1b 100644 --- a/api/core/trigger/trigger_manager.py +++ b/api/core/trigger/trigger_manager.py @@ -51,9 +51,7 @@ class TriggerManager: return controllers @classmethod - def get_trigger_provider( - cls, tenant_id: str, provider_id: TriggerProviderID - ) -> PluginTriggerProviderController: + def get_trigger_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderController: """ Get a specific plugin trigger provider @@ -101,9 +99,7 @@ class TriggerManager: return provider.get_triggers() @classmethod - def get_trigger( - cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str - ) -> Optional[TriggerEntity]: + def get_trigger(cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str) -> Optional[TriggerEntity]: """ Get a specific trigger @@ -198,9 +194,7 @@ class TriggerManager: ) @classmethod - def get_provider_subscription_schema( - cls, tenant_id: str, provider_id: TriggerProviderID - ) -> list[ProviderConfig]: + def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[ProviderConfig]: """ Get provider subscription schema @@ -210,5 +204,6 @@ class TriggerManager: """ return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema() + # Export __all__ = ["TriggerManager"] diff --git a/api/core/trigger/utils/encryption.py b/api/core/trigger/utils/encryption.py index 19a76c8927..2abfa604c2 100644 --- a/api/core/trigger/utils/encryption.py +++ b/api/core/trigger/utils/encryption.py @@ -1,6 +1,8 @@ +from typing import Union + from core.helper.provider_cache import TriggerProviderCredentialCache, TriggerProviderOAuthClientCache +from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType -from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity from core.trigger.provider import PluginTriggerProviderController from models.trigger import TriggerProvider @@ -9,41 +11,47 @@ from models.trigger import TriggerProvider def create_trigger_provider_encrypter_for_credential( tenant_id: str, controller: PluginTriggerProviderController, - credential: TriggerProvider | TriggerProviderCredentialApiEntity, + credential: Union[TriggerProvider, TriggerProviderCredentialApiEntity], ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]: - return create_provider_encrypter( + cache = TriggerProviderCredentialCache( + tenant_id=tenant_id, + provider_id=str(controller.get_provider_id()), + credential_id=credential.id, + ) + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=controller.get_credential_schema_config(credential.credential_type), - cache=TriggerProviderCredentialCache( - tenant_id=tenant_id, - provider_id=str(controller.get_provider_id()), - credential_id=credential.id, - ), + cache=cache, ) + return encrypter, cache def create_trigger_provider_encrypter( tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]: - return create_provider_encrypter( + cache = TriggerProviderCredentialCache( + tenant_id=tenant_id, + provider_id=str(controller.get_provider_id()), + credential_id=credential_id, + ) + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=controller.get_credential_schema_config(credential_type), - cache=TriggerProviderCredentialCache( - tenant_id=tenant_id, - provider_id=str(controller.get_provider_id()), - credential_id=credential_id, - ), + cache=cache, ) + return encrypter, cache def create_trigger_provider_oauth_encrypter( tenant_id: str, controller: PluginTriggerProviderController ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]: - return create_provider_encrypter( + cache = TriggerProviderOAuthClientCache( + tenant_id=tenant_id, + provider_id=str(controller.get_provider_id()), + ) + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()], - cache=TriggerProviderOAuthClientCache( - tenant_id=tenant_id, - provider_id=str(controller.get_provider_id()), - ), + cache=cache, ) + return encrypter, cache diff --git a/api/models/trigger.py b/api/models/trigger.py index 8941e44fa2..2db3a627dd 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -65,6 +65,7 @@ class TriggerProvider(Base): credentials=self.credentials, ) + # system level trigger oauth client params class TriggerOAuthSystemClient(Base): __tablename__ = "trigger_oauth_system_clients" diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 382cc18215..4628ec9c18 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -10,12 +10,10 @@ from sqlalchemy.orm import Session from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.provider_cache import NoOpProviderCredentialCache +from core.helper.provider_encryption import create_provider_encrypter from core.plugin.entities.plugin import TriggerProviderID from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.utils.encryption import ( - create_provider_encrypter, -) from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity from core.trigger.trigger_manager import TriggerManager diff --git a/api/services/trigger_service.py b/api/services/trigger_service.py index 434ef27aea..fa56e19773 100644 --- a/api/services/trigger_service.py +++ b/api/services/trigger_service.py @@ -20,4 +20,4 @@ class TriggerService: # TODO dispatch by the trigger controller # TODO using the dispatch result(events) to invoke the trigger events - raise NotImplementedError("Not implemented") \ No newline at end of file + raise NotImplementedError("Not implemented") diff --git a/api/tests/unit_tests/core/tools/utils/test_encryption.py b/api/tests/unit_tests/core/tools/utils/test_encryption.py index 6425ab0b8d..3b7c1f5678 100644 --- a/api/tests/unit_tests/core/tools/utils/test_encryption.py +++ b/api/tests/unit_tests/core/tools/utils/test_encryption.py @@ -70,7 +70,7 @@ def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj data_in = {"username": "alice", "password": "plain_pwd"} data_copy = copy.deepcopy(data_in) - with patch("core.tools.utils.encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt: + with patch("core.helper.provider_encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt: out = encrypter_obj.encrypt(data_in) assert out["username"] == "alice" @@ -81,7 +81,7 @@ def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj def test_encrypt_missing_secret_key_is_ok(encrypter_obj): """If secret field missing in input, no error and no encryption called.""" - with patch("core.tools.utils.encryption.encrypter.encrypt_token") as mock_encrypt: + with patch("core.helper.provider_encryption.encrypter.encrypt_token") as mock_encrypt: out = encrypter_obj.encrypt({"username": "alice"}) assert out["username"] == "alice" mock_encrypt.assert_not_called() @@ -151,7 +151,7 @@ def test_decrypt_normal_flow(encrypter_obj): data_in = {"username": "alice", "password": "ENC"} data_copy = copy.deepcopy(data_in) - with patch("core.tools.utils.encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt: + with patch("core.helper.provider_encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt: out = encrypter_obj.decrypt(data_in) assert out["username"] == "alice" @@ -163,7 +163,7 @@ def test_decrypt_normal_flow(encrypter_obj): @pytest.mark.parametrize("empty_val", ["", None]) def test_decrypt_skip_empty_values(encrypter_obj, empty_val): """Skip decrypt if value is empty or None, keep original.""" - with patch("core.tools.utils.encryption.encrypter.decrypt_token") as mock_decrypt: + with patch("core.helper.provider_encryption.encrypter.decrypt_token") as mock_decrypt: out = encrypter_obj.decrypt({"password": empty_val}) mock_decrypt.assert_not_called() @@ -175,7 +175,7 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj): If decrypt_token raises, exception should be swallowed, and original value preserved. """ - with patch("core.tools.utils.encryption.encrypter.decrypt_token", side_effect=Exception("boom")): + with patch("core.helper.provider_encryption.encrypter.decrypt_token", side_effect=Exception("boom")): out = encrypter_obj.decrypt({"password": "ENC_ERR"}) assert out["password"] == "ENC_ERR"