refactor(trigger): clean up and optimize trigger-related code

- Remove unused classes and imports in encryption utilities
- Simplify method signatures for better readability
- Enhance code quality by adding newlines for clarity
- Update tests to reflect changes in import paths

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Harry 2025-08-28 15:53:48 +08:00
parent a46c9238fa
commit 72f9e77368
10 changed files with 184 additions and 160 deletions

View File

@ -80,6 +80,7 @@ class TriggerProviderCredentialCache(ProviderCredentialsCache):
credential_id = kwargs["credential_id"]
return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
class TriggerProviderOAuthClientCache(ProviderCredentialsCache):
"""Cache for trigger provider OAuth client"""
@ -91,6 +92,7 @@ class TriggerProviderOAuthClientCache(ProviderCredentialsCache):
provider_id = kwargs["provider_id"]
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""

View File

@ -0,0 +1,125 @@
import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache

View File

@ -1,131 +1,23 @@
import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol
# Import generic components from provider_encryption module
from core.helper.provider_encryption import (
ProviderConfigCache,
ProviderConfigEncrypter,
create_provider_encrypter,
)
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
# Re-export for backward compatibility
__all__ = [
"ProviderConfigCache",
"ProviderConfigEncrypter",
"create_provider_encrypter",
"create_tool_provider_encrypter",
]
# Tool-specific imports
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,

View File

@ -70,6 +70,7 @@ class TriggerIdentity(BaseModel):
label: I18nObject = Field(..., description="The label of the trigger")
provider: str = Field(..., description="The provider of the trigger")
class TriggerDescription(BaseModel):
"""
The description of the trigger
@ -91,12 +92,14 @@ class TriggerEntity(BaseModel):
default=None, description="The output schema that this trigger produces"
)
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class TriggerProviderEntity(BaseModel):
"""
The configuration of a trigger provider

View File

@ -51,9 +51,7 @@ class TriggerManager:
return controllers
@classmethod
def get_trigger_provider(
cls, tenant_id: str, provider_id: TriggerProviderID
) -> PluginTriggerProviderController:
def get_trigger_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderController:
"""
Get a specific plugin trigger provider
@ -101,9 +99,7 @@ class TriggerManager:
return provider.get_triggers()
@classmethod
def get_trigger(
cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str
) -> Optional[TriggerEntity]:
def get_trigger(cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str) -> Optional[TriggerEntity]:
"""
Get a specific trigger
@ -198,9 +194,7 @@ class TriggerManager:
)
@classmethod
def get_provider_subscription_schema(
cls, tenant_id: str, provider_id: TriggerProviderID
) -> list[ProviderConfig]:
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[ProviderConfig]:
"""
Get provider subscription schema
@ -210,5 +204,6 @@ class TriggerManager:
"""
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
# Export
__all__ = ["TriggerManager"]

View File

@ -1,6 +1,8 @@
from typing import Union
from core.helper.provider_cache import TriggerProviderCredentialCache, TriggerProviderOAuthClientCache
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity
from core.trigger.provider import PluginTriggerProviderController
from models.trigger import TriggerProvider
@ -9,41 +11,47 @@ from models.trigger import TriggerProvider
def create_trigger_provider_encrypter_for_credential(
tenant_id: str,
controller: PluginTriggerProviderController,
credential: TriggerProvider | TriggerProviderCredentialApiEntity,
credential: Union[TriggerProvider, TriggerProviderCredentialApiEntity],
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
cache = TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential.id,
)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_credential_schema_config(credential.credential_type),
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential.id,
),
cache=cache,
)
return encrypter, cache
def create_trigger_provider_encrypter(
tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
cache = TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential_id,
)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=controller.get_credential_schema_config(credential_type),
cache=TriggerProviderCredentialCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
credential_id=credential_id,
),
cache=cache,
)
return encrypter, cache
def create_trigger_provider_oauth_encrypter(
tenant_id: str, controller: PluginTriggerProviderController
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
return create_provider_encrypter(
cache = TriggerProviderOAuthClientCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
cache=TriggerProviderOAuthClientCache(
tenant_id=tenant_id,
provider_id=str(controller.get_provider_id()),
),
cache=cache,
)
return encrypter, cache

View File

@ -65,6 +65,7 @@ class TriggerProvider(Base):
credentials=self.credentials,
)
# system level trigger oauth client params
class TriggerOAuthSystemClient(Base):
__tablename__ = "trigger_oauth_system_clients"

View File

@ -10,12 +10,10 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import create_provider_encrypter
from core.plugin.entities.plugin import TriggerProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.encryption import (
create_provider_encrypter,
)
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity
from core.trigger.trigger_manager import TriggerManager

View File

@ -20,4 +20,4 @@ class TriggerService:
# TODO dispatch by the trigger controller
# TODO using the dispatch result(events) to invoke the trigger events
raise NotImplementedError("Not implemented")
raise NotImplementedError("Not implemented")

View File

@ -70,7 +70,7 @@ def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj
data_in = {"username": "alice", "password": "plain_pwd"}
data_copy = copy.deepcopy(data_in)
with patch("core.tools.utils.encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
with patch("core.helper.provider_encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
out = encrypter_obj.encrypt(data_in)
assert out["username"] == "alice"
@ -81,7 +81,7 @@ def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj
def test_encrypt_missing_secret_key_is_ok(encrypter_obj):
"""If secret field missing in input, no error and no encryption called."""
with patch("core.tools.utils.encryption.encrypter.encrypt_token") as mock_encrypt:
with patch("core.helper.provider_encryption.encrypter.encrypt_token") as mock_encrypt:
out = encrypter_obj.encrypt({"username": "alice"})
assert out["username"] == "alice"
mock_encrypt.assert_not_called()
@ -151,7 +151,7 @@ def test_decrypt_normal_flow(encrypter_obj):
data_in = {"username": "alice", "password": "ENC"}
data_copy = copy.deepcopy(data_in)
with patch("core.tools.utils.encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
with patch("core.helper.provider_encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
out = encrypter_obj.decrypt(data_in)
assert out["username"] == "alice"
@ -163,7 +163,7 @@ def test_decrypt_normal_flow(encrypter_obj):
@pytest.mark.parametrize("empty_val", ["", None])
def test_decrypt_skip_empty_values(encrypter_obj, empty_val):
"""Skip decrypt if value is empty or None, keep original."""
with patch("core.tools.utils.encryption.encrypter.decrypt_token") as mock_decrypt:
with patch("core.helper.provider_encryption.encrypter.decrypt_token") as mock_decrypt:
out = encrypter_obj.decrypt({"password": empty_val})
mock_decrypt.assert_not_called()
@ -175,7 +175,7 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
If decrypt_token raises, exception should be swallowed,
and original value preserved.
"""
with patch("core.tools.utils.encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
with patch("core.helper.provider_encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
out = encrypter_obj.decrypt({"password": "ENC_ERR"})
assert out["password"] == "ENC_ERR"