mirror of https://github.com/langgenius/dify.git
feat(oauth): refactor tool encryption utils
This commit is contained in:
parent
eaefa1b7e6
commit
0dc5bfb2c7
|
|
@ -37,16 +37,23 @@ class ProviderCredentialsCache(ABC):
|
|||
redis_client.delete(self.cache_key)
|
||||
|
||||
|
||||
class GenericProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for generic provider credentials"""
|
||||
class SingletonProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for tool single provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, identity_id: str):
|
||||
super().__init__(tenant_id=tenant_id, identity_id=identity_id)
|
||||
def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
|
||||
super().__init__(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=provider_type,
|
||||
provider_identity=provider_identity,
|
||||
)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
identity_id = kwargs["identity_id"]
|
||||
return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
provider_type = kwargs["provider_type"]
|
||||
identity_name = kwargs["provider_identity"]
|
||||
identity_id = f"{provider_type}.{identity_name}"
|
||||
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
|
||||
class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for tool provider credentials"""
|
||||
|
|
@ -58,7 +65,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
|||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
credential_id = kwargs["credential_id"]
|
||||
return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class NoOpProviderCredentialCache:
|
||||
|
|
|
|||
|
|
@ -1,16 +1,20 @@
|
|||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.plugin.entities.request import RequestInvokeEncrypt
|
||||
from core.tools.utils.encryption import create_generic_encrypter
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
class PluginEncrypter:
|
||||
@classmethod
|
||||
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
|
||||
encrypter, cache = create_generic_encrypter(
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant.id,
|
||||
config=payload.config,
|
||||
provider_type=payload.namespace,
|
||||
provider_identity=payload.identity,
|
||||
cache=SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant.id,
|
||||
provider_type=payload.namespace,
|
||||
provider_identity=payload.identity,
|
||||
),
|
||||
)
|
||||
|
||||
if payload.opt == "encrypt":
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ from core.tools.tool_label_manager import ToolLabelManager
|
|||
from core.tools.utils.configuration import (
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter, create_encrypter, create_generic_encrypter
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
|
|
@ -222,7 +222,7 @@ class ToolManager:
|
|||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
|
||||
|
||||
encrypter, _ = create_encrypter(
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
|
|
@ -248,11 +248,9 @@ class ToolManager:
|
|||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
encrypter, _ = create_generic_encrypter(
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
|
||||
provider_type=api_provider.provider_type.value,
|
||||
provider_identity=api_provider.entity.identity.name,
|
||||
controller=api_provider,
|
||||
)
|
||||
return cast(
|
||||
ApiTool,
|
||||
|
|
@ -740,15 +738,12 @@ class ToolManager:
|
|||
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter.create_cached(
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
controller=controller,
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
|
||||
|
||||
try:
|
||||
icon = json.loads(provider_obj.icon)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from typing import Any, Optional, Protocol
|
|||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import GenericProviderCredentialsCache
|
||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
|
|
@ -123,13 +124,18 @@ class ProviderConfigEncrypter:
|
|||
return data
|
||||
|
||||
|
||||
def create_generic_encrypter(
|
||||
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
|
||||
):
|
||||
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
|
||||
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
|
||||
return encrypt, cache
|
||||
|
||||
|
||||
def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
encrypt = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_config_cache=cache,
|
||||
)
|
||||
return encrypt, cache
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
|||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.plugin.impl.dynamic_select import DynamicSelectClient
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.tools import BuiltinToolProvider
|
||||
|
||||
|
|
@ -38,11 +38,9 @@ class PluginParameterService:
|
|||
case "tool":
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
controller=provider_controller,
|
||||
)
|
||||
|
||||
# check if credentials are required
|
||||
|
|
@ -63,7 +61,7 @@ class PluginParameterService:
|
|||
if db_record is None:
|
||||
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
|
||||
|
||||
credentials = tool_configuration.decrypt(db_record.credentials)
|
||||
credentials = encrypter.decrypt(db_record.credentials)
|
||||
case _:
|
||||
raise ValueError(f"Invalid provider type: {provider_type}")
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
|
|||
)
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter
|
||||
from core.tools.utils.encryption import create_tool_provider_encrypter
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
|
|
@ -164,15 +164,11 @@ class ApiToolManageService:
|
|||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# encrypt credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
controller=provider_controller,
|
||||
)
|
||||
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
db_provider.credentials_str = json.dumps(encrypted_credentials)
|
||||
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
|
||||
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
|
|
@ -297,11 +293,9 @@ class ApiToolManageService:
|
|||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# get original credentials if exists
|
||||
encrypter, cache = create_generic_encrypter(
|
||||
encrypter, cache = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
controller=provider_controller,
|
||||
)
|
||||
|
||||
original_credentials = encrypter.decrypt(provider.credentials)
|
||||
|
|
@ -416,11 +410,9 @@ class ApiToolManageService:
|
|||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
encrypter, _ = create_generic_encrypter(
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
controller=provider_controller,
|
||||
)
|
||||
decrypted_credentials = encrypter.decrypt(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
|
|||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.encryption import create_encrypter
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||
|
|
@ -225,7 +225,7 @@ class BuiltinToolManageService:
|
|||
provider: str,
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
):
|
||||
encrypter, cache = create_encrypter(
|
||||
encrypter, cache = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
|
|
@ -396,7 +396,7 @@ class BuiltinToolManageService:
|
|||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
encrypter, _ = create_encrypter(
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
|
|
@ -608,7 +608,7 @@ class BuiltinToolManageService:
|
|||
session.add(custom_client_params)
|
||||
|
||||
if client_params is not None:
|
||||
encrypter, _ = create_encrypter(
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
|
|
@ -647,7 +647,7 @@ class BuiltinToolManageService:
|
|||
if not isinstance(provider_controller, BuiltinToolProviderController):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
encrypter, _ = create_encrypter(
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from core.tools.entities.tool_entities import (
|
|||
ToolProviderType,
|
||||
)
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.utils.encryption import create_encrypter, create_generic_encrypter
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
|
|
@ -113,9 +113,7 @@ class ToolTransformService:
|
|||
schema = {
|
||||
x.to_basic_provider_config().name: x
|
||||
for x in provider_controller.get_credentials_schema_by_type(
|
||||
CredentialType.of(db_provider.credential_type)
|
||||
if db_provider
|
||||
else CredentialType.API_KEY
|
||||
CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -134,7 +132,7 @@ class ToolTransformService:
|
|||
credentials = db_provider.credentials
|
||||
|
||||
# init tool configuration
|
||||
encrypter, _ = create_encrypter(
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
|
|
@ -252,11 +250,9 @@ class ToolTransformService:
|
|||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
encrypter, _ = create_generic_encrypter(
|
||||
encrypter, _ = create_tool_provider_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
controller=provider_controller,
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
|
|
|
|||
Loading…
Reference in New Issue