mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
feat(oauth): refactor encryption
This commit is contained in:
parent
9f053f3bbc
commit
eaefa1b7e6
@ -1,5 +1,5 @@
|
|||||||
from core.plugin.entities.request import RequestInvokeEncrypt
|
from core.plugin.entities.request import RequestInvokeEncrypt
|
||||||
from core.tools.utils.configuration import create_generic_encrypter
|
from core.tools.utils.encryption import create_generic_encrypter
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -45,11 +45,9 @@ from core.tools.entities.tool_entities import (
|
|||||||
from core.tools.errors import ToolProviderNotFoundError
|
from core.tools.errors import ToolProviderNotFoundError
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import (
|
from core.tools.utils.configuration import (
|
||||||
ProviderConfigEncrypter,
|
|
||||||
ToolParameterConfigurationManager,
|
ToolParameterConfigurationManager,
|
||||||
create_encrypter,
|
|
||||||
create_generic_encrypter,
|
|
||||||
)
|
)
|
||||||
|
from core.tools.utils.encryption import ProviderConfigEncrypter, create_encrypter, create_generic_encrypter
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||||
|
|||||||
@ -1,9 +1,7 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Optional, Protocol
|
from typing import Any
|
||||||
|
|
||||||
from core.entities.provider_entities import BasicProviderConfig
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.helper.provider_cache import GenericProviderCredentialsCache
|
|
||||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
@ -12,139 +10,6 @@ from core.tools.entities.tool_entities import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProviderConfigCache(Protocol):
|
|
||||||
"""
|
|
||||||
Interface for provider configuration cache operations
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get(self) -> Optional[dict]:
|
|
||||||
"""Get cached provider configuration"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def set(self, config: dict[str, Any]) -> None:
|
|
||||||
"""Cache provider configuration"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def delete(self) -> None:
|
|
||||||
"""Delete cached provider configuration"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderConfigEncrypter:
|
|
||||||
tenant_id: str
|
|
||||||
config: list[BasicProviderConfig]
|
|
||||||
provider_config_cache: ProviderConfigCache
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tenant_id: str,
|
|
||||||
config: list[BasicProviderConfig],
|
|
||||||
provider_config_cache: ProviderConfigCache,
|
|
||||||
):
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.config = config
|
|
||||||
self.provider_config_cache = provider_config_cache
|
|
||||||
|
|
||||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
|
||||||
"""
|
|
||||||
deep copy data
|
|
||||||
"""
|
|
||||||
return deepcopy(data)
|
|
||||||
|
|
||||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
|
||||||
"""
|
|
||||||
encrypt tool credentials with tenant id
|
|
||||||
|
|
||||||
return a deep copy of credentials with encrypted values
|
|
||||||
"""
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
|
||||||
data[field_name] = encrypted
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
mask tool credentials
|
|
||||||
|
|
||||||
return a deep copy of credentials with masked values
|
|
||||||
"""
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
if len(data[field_name]) > 6:
|
|
||||||
data[field_name] = (
|
|
||||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
data[field_name] = "*" * len(data[field_name])
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
decrypt tool credentials with tenant id
|
|
||||||
|
|
||||||
return a deep copy of credentials with decrypted values
|
|
||||||
"""
|
|
||||||
cached_credentials = self.provider_config_cache.get()
|
|
||||||
if cached_credentials:
|
|
||||||
return cached_credentials
|
|
||||||
|
|
||||||
data = self._deep_copy(data)
|
|
||||||
# get fields need to be decrypted
|
|
||||||
fields = dict[str, BasicProviderConfig]()
|
|
||||||
for credential in self.config:
|
|
||||||
fields[credential.name] = credential
|
|
||||||
|
|
||||||
for field_name, field in fields.items():
|
|
||||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
|
||||||
if field_name in data:
|
|
||||||
try:
|
|
||||||
# if the value is None or empty string, skip decrypt
|
|
||||||
if not data[field_name]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.provider_config_cache.set(data)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def create_encrypter(
|
|
||||||
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
|
|
||||||
):
|
|
||||||
return ProviderConfigEncrypter(
|
|
||||||
tenant_id=tenant_id, config=config, provider_config_cache=cache
|
|
||||||
), cache
|
|
||||||
|
|
||||||
|
|
||||||
def create_generic_encrypter(
|
|
||||||
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
|
|
||||||
):
|
|
||||||
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
|
|
||||||
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
|
|
||||||
return encrypt, cache
|
|
||||||
|
|
||||||
|
|
||||||
class ToolParameterConfigurationManager:
|
class ToolParameterConfigurationManager:
|
||||||
"""
|
"""
|
||||||
Tool parameter configuration manager
|
Tool parameter configuration manager
|
||||||
|
|||||||
135
api/core/tools/utils/encryption.py
Normal file
135
api/core/tools/utils/encryption.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Optional, Protocol
|
||||||
|
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.helper.provider_cache import GenericProviderCredentialsCache
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigCache(Protocol):
|
||||||
|
"""
|
||||||
|
Interface for provider configuration cache operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""Get cached provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def set(self, config: dict[str, Any]) -> None:
|
||||||
|
"""Cache provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""Delete cached provider configuration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigEncrypter:
|
||||||
|
tenant_id: str
|
||||||
|
config: list[BasicProviderConfig]
|
||||||
|
provider_config_cache: ProviderConfigCache
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
config: list[BasicProviderConfig],
|
||||||
|
provider_config_cache: ProviderConfigCache,
|
||||||
|
):
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.config = config
|
||||||
|
self.provider_config_cache = provider_config_cache
|
||||||
|
|
||||||
|
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
deep copy data
|
||||||
|
"""
|
||||||
|
return deepcopy(data)
|
||||||
|
|
||||||
|
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
encrypt tool credentials with tenant id
|
||||||
|
|
||||||
|
return a deep copy of credentials with encrypted values
|
||||||
|
"""
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||||
|
data[field_name] = encrypted
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
mask tool credentials
|
||||||
|
|
||||||
|
return a deep copy of credentials with masked values
|
||||||
|
"""
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
if len(data[field_name]) > 6:
|
||||||
|
data[field_name] = (
|
||||||
|
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data[field_name] = "*" * len(data[field_name])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
decrypt tool credentials with tenant id
|
||||||
|
|
||||||
|
return a deep copy of credentials with decrypted values
|
||||||
|
"""
|
||||||
|
cached_credentials = self.provider_config_cache.get()
|
||||||
|
if cached_credentials:
|
||||||
|
return cached_credentials
|
||||||
|
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
try:
|
||||||
|
# if the value is None or empty string, skip decrypt
|
||||||
|
if not data[field_name]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.provider_config_cache.set(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def create_generic_encrypter(
|
||||||
|
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
|
||||||
|
):
|
||||||
|
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
|
||||||
|
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
|
||||||
|
return encrypt, cache
|
||||||
|
|
||||||
|
|
||||||
|
def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||||
|
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||||||
from core.plugin.entities.parameters import PluginParameterOption
|
from core.plugin.entities.parameters import PluginParameterOption
|
||||||
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import BuiltinToolProvider
|
from models.tools import BuiltinToolProvider
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
|
|||||||
)
|
)
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter
|
from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter
|
||||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider
|
from models.tools import ApiToolProvider
|
||||||
|
|||||||
@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
|
|||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.configuration import create_encrypter
|
from core.tools.utils.encryption import create_encrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from core.tools.entities.tool_entities import (
|
|||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.tools.utils.configuration import create_encrypter, create_generic_encrypter
|
from core.tools.utils.encryption import create_encrypter, create_generic_encrypter
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user