mirror of https://github.com/langgenius/dify.git
refactor(trigger): clean up and optimize trigger-related code
- Remove unused classes and imports in encryption utilities - Simplify method signatures for better readability - Enhance code quality by adding newlines for clarity - Update tests to reflect changes in import paths Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
a46c9238fa
commit
72f9e77368
|
|
@ -80,6 +80,7 @@ class TriggerProviderCredentialCache(ProviderCredentialsCache):
|
|||
credential_id = kwargs["credential_id"]
|
||||
return f"trigger_credentials:tenant_id:{tenant_id}:provider_id:{provider_id}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class TriggerProviderOAuthClientCache(ProviderCredentialsCache):
|
||||
"""Cache for trigger provider OAuth client"""
|
||||
|
||||
|
|
@ -91,6 +92,7 @@ class TriggerProviderOAuthClientCache(ProviderCredentialsCache):
|
|||
provider_id = kwargs["provider_id"]
|
||||
return f"trigger_oauth_client:tenant_id:{tenant_id}:provider_id:{provider_id}"
|
||||
|
||||
|
||||
class NoOpProviderCredentialCache:
|
||||
"""No-op provider credential cache"""
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,125 @@
|
|||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
with contextlib.suppress(Exception):
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
|
@ -1,131 +1,23 @@
|
|||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
# Import generic components from provider_encryption module
|
||||
from core.helper.provider_encryption import (
|
||||
ProviderConfigCache,
|
||||
ProviderConfigEncrypter,
|
||||
create_provider_encrypter,
|
||||
)
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
# Re-export for backward compatibility
|
||||
__all__ = [
|
||||
"ProviderConfigCache",
|
||||
"ProviderConfigEncrypter",
|
||||
"create_provider_encrypter",
|
||||
"create_tool_provider_encrypter",
|
||||
]
|
||||
|
||||
# Tool-specific imports
|
||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
with contextlib.suppress(Exception):
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ class TriggerIdentity(BaseModel):
|
|||
label: I18nObject = Field(..., description="The label of the trigger")
|
||||
provider: str = Field(..., description="The provider of the trigger")
|
||||
|
||||
|
||||
class TriggerDescription(BaseModel):
|
||||
"""
|
||||
The description of the trigger
|
||||
|
|
@ -91,12 +92,14 @@ class TriggerEntity(BaseModel):
|
|||
default=None, description="The output schema that this trigger produces"
|
||||
)
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class TriggerProviderEntity(BaseModel):
|
||||
"""
|
||||
The configuration of a trigger provider
|
||||
|
|
|
|||
|
|
@ -51,9 +51,7 @@ class TriggerManager:
|
|||
return controllers
|
||||
|
||||
@classmethod
|
||||
def get_trigger_provider(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID
|
||||
) -> PluginTriggerProviderController:
|
||||
def get_trigger_provider(cls, tenant_id: str, provider_id: TriggerProviderID) -> PluginTriggerProviderController:
|
||||
"""
|
||||
Get a specific plugin trigger provider
|
||||
|
||||
|
|
@ -101,9 +99,7 @@ class TriggerManager:
|
|||
return provider.get_triggers()
|
||||
|
||||
@classmethod
|
||||
def get_trigger(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str
|
||||
) -> Optional[TriggerEntity]:
|
||||
def get_trigger(cls, tenant_id: str, provider_id: TriggerProviderID, trigger_name: str) -> Optional[TriggerEntity]:
|
||||
"""
|
||||
Get a specific trigger
|
||||
|
||||
|
|
@ -198,9 +194,7 @@ class TriggerManager:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def get_provider_subscription_schema(
|
||||
cls, tenant_id: str, provider_id: TriggerProviderID
|
||||
) -> list[ProviderConfig]:
|
||||
def get_provider_subscription_schema(cls, tenant_id: str, provider_id: TriggerProviderID) -> list[ProviderConfig]:
|
||||
"""
|
||||
Get provider subscription schema
|
||||
|
||||
|
|
@ -210,5 +204,6 @@ class TriggerManager:
|
|||
"""
|
||||
return cls.get_trigger_provider(tenant_id, provider_id).get_subscription_schema()
|
||||
|
||||
|
||||
# Export
|
||||
__all__ = ["TriggerManager"]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from typing import Union
|
||||
|
||||
from core.helper.provider_cache import TriggerProviderCredentialCache, TriggerProviderOAuthClientCache
|
||||
from core.helper.provider_encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from core.trigger.entities.api_entities import TriggerProviderCredentialApiEntity
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from models.trigger import TriggerProvider
|
||||
|
|
@ -9,41 +11,47 @@ from models.trigger import TriggerProvider
|
|||
def create_trigger_provider_encrypter_for_credential(
|
||||
tenant_id: str,
|
||||
controller: PluginTriggerProviderController,
|
||||
credential: TriggerProvider | TriggerProviderCredentialApiEntity,
|
||||
credential: Union[TriggerProvider, TriggerProviderCredentialApiEntity],
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return create_provider_encrypter(
|
||||
cache = TriggerProviderCredentialCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
credential_id=credential.id,
|
||||
)
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credential_schema_config(credential.credential_type),
|
||||
cache=TriggerProviderCredentialCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
credential_id=credential.id,
|
||||
),
|
||||
cache=cache,
|
||||
)
|
||||
return encrypter, cache
|
||||
|
||||
|
||||
def create_trigger_provider_encrypter(
|
||||
tenant_id: str, controller: PluginTriggerProviderController, credential_id: str, credential_type: CredentialType
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return create_provider_encrypter(
|
||||
cache = TriggerProviderCredentialCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
credential_id=credential_id,
|
||||
)
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=controller.get_credential_schema_config(credential_type),
|
||||
cache=TriggerProviderCredentialCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
credential_id=credential_id,
|
||||
),
|
||||
cache=cache,
|
||||
)
|
||||
return encrypter, cache
|
||||
|
||||
|
||||
def create_trigger_provider_oauth_encrypter(
|
||||
tenant_id: str, controller: PluginTriggerProviderController
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return create_provider_encrypter(
|
||||
cache = TriggerProviderOAuthClientCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
)
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_oauth_client_schema()],
|
||||
cache=TriggerProviderOAuthClientCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=str(controller.get_provider_id()),
|
||||
),
|
||||
cache=cache,
|
||||
)
|
||||
return encrypter, cache
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ class TriggerProvider(Base):
|
|||
credentials=self.credentials,
|
||||
)
|
||||
|
||||
|
||||
# system level trigger oauth client params
|
||||
class TriggerOAuthSystemClient(Base):
|
||||
__tablename__ = "trigger_oauth_system_clients"
|
||||
|
|
|
|||
|
|
@ -10,12 +10,10 @@ from sqlalchemy.orm import Session
|
|||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.provider_encryption import create_provider_encrypter
|
||||
from core.plugin.entities.plugin import TriggerProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.tools.utils.encryption import (
|
||||
create_provider_encrypter,
|
||||
)
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from core.trigger.entities.api_entities import TriggerProviderApiEntity, TriggerProviderCredentialApiEntity
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
|
|
|
|||
|
|
@ -20,4 +20,4 @@ class TriggerService:
|
|||
# TODO dispatch by the trigger controller
|
||||
|
||||
# TODO using the dispatch result(events) to invoke the trigger events
|
||||
raise NotImplementedError("Not implemented")
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj
|
|||
data_in = {"username": "alice", "password": "plain_pwd"}
|
||||
data_copy = copy.deepcopy(data_in)
|
||||
|
||||
with patch("core.tools.utils.encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
|
||||
with patch("core.helper.provider_encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
|
||||
out = encrypter_obj.encrypt(data_in)
|
||||
|
||||
assert out["username"] == "alice"
|
||||
|
|
@ -81,7 +81,7 @@ def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj
|
|||
|
||||
def test_encrypt_missing_secret_key_is_ok(encrypter_obj):
|
||||
"""If secret field missing in input, no error and no encryption called."""
|
||||
with patch("core.tools.utils.encryption.encrypter.encrypt_token") as mock_encrypt:
|
||||
with patch("core.helper.provider_encryption.encrypter.encrypt_token") as mock_encrypt:
|
||||
out = encrypter_obj.encrypt({"username": "alice"})
|
||||
assert out["username"] == "alice"
|
||||
mock_encrypt.assert_not_called()
|
||||
|
|
@ -151,7 +151,7 @@ def test_decrypt_normal_flow(encrypter_obj):
|
|||
data_in = {"username": "alice", "password": "ENC"}
|
||||
data_copy = copy.deepcopy(data_in)
|
||||
|
||||
with patch("core.tools.utils.encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
|
||||
with patch("core.helper.provider_encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
|
||||
out = encrypter_obj.decrypt(data_in)
|
||||
|
||||
assert out["username"] == "alice"
|
||||
|
|
@ -163,7 +163,7 @@ def test_decrypt_normal_flow(encrypter_obj):
|
|||
@pytest.mark.parametrize("empty_val", ["", None])
|
||||
def test_decrypt_skip_empty_values(encrypter_obj, empty_val):
|
||||
"""Skip decrypt if value is empty or None, keep original."""
|
||||
with patch("core.tools.utils.encryption.encrypter.decrypt_token") as mock_decrypt:
|
||||
with patch("core.helper.provider_encryption.encrypter.decrypt_token") as mock_decrypt:
|
||||
out = encrypter_obj.decrypt({"password": empty_val})
|
||||
|
||||
mock_decrypt.assert_not_called()
|
||||
|
|
@ -175,7 +175,7 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
|
|||
If decrypt_token raises, exception should be swallowed,
|
||||
and original value preserved.
|
||||
"""
|
||||
with patch("core.tools.utils.encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
|
||||
with patch("core.helper.provider_encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
|
||||
out = encrypter_obj.decrypt({"password": "ENC_ERR"})
|
||||
|
||||
assert out["password"] == "ENC_ERR"
|
||||
|
|
|
|||
Loading…
Reference in New Issue