diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 5da20c3d29..090d5f3cee 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -82,7 +82,7 @@ class ToolBuiltinProviderInfoApi(Resource): user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) + return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) class ToolBuiltinProviderDeleteApi(Resource): @@ -159,7 +159,7 @@ class ToolBuiltinProviderUpdateApi(Resource): result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, - provider_name=provider, + provider=provider, credentials=args["credentials"], credential_id=args["credential_id"], name=args["name"], @@ -782,7 +782,6 @@ class ToolOAuthCustomClient(Resource): return BuiltinToolManageService.setup_oauth_custom_client( tenant_id=user.current_tenant_id, - user_id=user.id, provider=provider, client_params=args["client_params"], ) diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py new file mode 100644 index 0000000000..3e70ea5341 --- /dev/null +++ b/api/core/helper/provider_cache.py @@ -0,0 +1,77 @@ +import json +from abc import ABC, abstractmethod +from json import JSONDecodeError +from typing import Any, Optional + +from extensions.ext_redis import redis_client + + +class ProviderCredentialsCache(ABC): + """Base class for provider credentials cache""" + + def __init__(self, **kwargs): + self.cache_key = self._generate_cache_key(**kwargs) + + @abstractmethod + def _generate_cache_key(self, **kwargs) -> str: + """Generate cache key based on subclass implementation""" + pass + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + cached_credentials = redis_client.get(self.cache_key) + if cached_credentials: + try: + cached_credentials = cached_credentials.decode("utf-8") + return dict(json.loads(cached_credentials)) + except JSONDecodeError: + return None + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + redis_client.setex(self.cache_key, 86400, json.dumps(config)) + + def delete(self) -> None: + """Delete cached provider credentials""" + redis_client.delete(self.cache_key) + + +class GenericProviderCredentialsCache(ProviderCredentialsCache): + """Cache for generic provider credentials""" + + def __init__(self, tenant_id: str, identity_id: str): + super().__init__(tenant_id=tenant_id, identity_id=identity_id) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + identity_id = kwargs["identity_id"] + return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}" + +class ToolProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool provider credentials""" + + def __init__(self, tenant_id: str, provider: str, credential_id: str): + super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider = kwargs["provider"] + credential_id = kwargs["credential_id"] + return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" + + +class NoOpProviderCredentialCache: + """No-op provider credential cache""" + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + pass + + def delete(self) -> None: + """Delete cached provider credentials""" + pass diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py deleted file mode 100644 index 2e4a04c579..0000000000 --- a/api/core/helper/tool_provider_cache.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -from enum import Enum -from json import JSONDecodeError -from typing import Optional - -from extensions.ext_redis import redis_client - - -class ToolProviderCredentialsCacheType(Enum): - PROVIDER = "tool_provider" - ENDPOINT = "endpoint" - - -class ToolProviderCredentialsCache: - def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - - def get(self) -> Optional[dict]: - """ - Get cached model provider credentials. - - :return: - """ - cached_provider_credentials = redis_client.get(self.cache_key) - if cached_provider_credentials: - try: - cached_provider_credentials = cached_provider_credentials.decode("utf-8") - cached_provider_credentials = json.loads(cached_provider_credentials) - except JSONDecodeError: - return None - - return dict(cached_provider_credentials) - else: - return None - - def set(self, credentials: dict) -> None: - """ - Cache model provider credentials. - - :param credentials: provider credentials - :return: - """ - redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) - - def delete(self) -> None: - """ - Delete cached model provider credentials. - - :return: - """ - redis_client.delete(self.cache_key) diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index 81a5d033a0..bfe9ffa4b0 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,12 +1,12 @@ from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.configuration import create_generic_encrypter from models.account import Tenant class PluginEncrypter: @classmethod def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: - encrypter = ProviderConfigEncrypter( + encrypter, cache = create_generic_encrypter( tenant_id=tenant.id, config=payload.config, provider_type=payload.namespace, @@ -22,7 +22,7 @@ class PluginEncrypter: "data": encrypter.decrypt(payload.data), } elif payload.opt == "clear": - encrypter.delete_tool_credentials_cache() + cache.delete() return { "data": {}, } diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 9e3c13849f..53affe9e97 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -105,20 +105,34 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.tools - def get_credentials_schema( - self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY - ) -> list[ProviderConfig]: + def get_credentials_schema(self) -> list[ProviderConfig]: """ returns the credentials schema of the provider :return: the credentials schema """ - if credential_type == ToolProviderCredentialType.OAUTH2: + return self.get_credentials_schema_by_type(ToolProviderCredentialType.API_KEY.value) + + def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: + """ + returns the credentials schema of the provider + + :param credential_type: the type of the credential + :return: the credentials schema of the provider + """ + if credential_type == ToolProviderCredentialType.OAUTH2.value: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] - elif credential_type == ToolProviderCredentialType.API_KEY: + if credential_type == ToolProviderCredentialType.API_KEY.value: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] - else: - raise ValueError(f"Invalid credential type: {credential_type}") + raise ValueError(f"Invalid credential type: {credential_type}") + + def get_oauth_client_schema(self) -> list[ProviderConfig]: + """ + returns the oauth client schema of the provider + + :return: the oauth client schema + """ + return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] def get_tools(self) -> list[BuiltinTool]: """ @@ -141,7 +155,11 @@ class BuiltinToolProviderController(ToolProviderController): :return: whether the provider needs credentials """ - return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 + return ( + self.entity.credentials_schema is not None + and len(self.entity.credentials_schema) != 0 + or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0) + ) @property def provider_type(self) -> ToolProviderType: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 35d4eb0c7e..e9423a6c49 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast from yarl import URL import contexts +from core.helper.provider_cache import ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -38,12 +39,16 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolInvokeFrom, ToolParameter, - ToolProviderCredentialType, ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager +from core.tools.utils.configuration import ( + ProviderConfigEncrypter, + ToolParameterConfigurationManager, + create_encrypter, + create_generic_encrypter, +) from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -206,19 +211,18 @@ class ToolManager: # decrypt the credentials credentials = builtin_provider.credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() - for x in provider_controller.get_credentials_schema( - ToolProviderCredentialType.of(builtin_provider.credential_type) - ) + for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) ], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id + ), ) - decrypted_credentials = tool_configuration.decrypt(credentials) + decrypted_credentials = encrypter.decrypt(credentials) return cast( BuiltinTool, @@ -235,22 +239,18 @@ class ToolManager: elif provider_type == ToolProviderType.API: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - - # decrypt the credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_generic_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], provider_type=api_provider.provider_type.value, provider_identity=api_provider.entity.identity.name, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(credentials), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, ) @@ -730,7 +730,7 @@ class ToolManager: ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + tool_configuration = ProviderConfigEncrypter.create_cached( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], provider_type=controller.provider_type.value, diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 6a5fba65bd..2b64703321 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,12 +1,10 @@ from copy import deepcopy -from typing import Any - -from pydantic import BaseModel +from typing import Any, Optional, Protocol from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter +from core.helper.provider_cache import GenericProviderCredentialsCache from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType -from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( ToolParameter, @@ -14,11 +12,38 @@ from core.tools.entities.tool_entities import ( ) -class ProviderConfigEncrypter(BaseModel): +class ProviderConfigCache(Protocol): + """ + Interface for provider configuration cache operations + """ + + def get(self) -> Optional[dict]: + """Get cached provider configuration""" + ... + + def set(self, config: dict[str, Any]) -> None: + """Cache provider configuration""" + ... + + def delete(self) -> None: + """Delete cached provider configuration""" + ... + + +class ProviderConfigEncrypter: tenant_id: str config: list[BasicProviderConfig] - provider_type: str - provider_identity: str + provider_config_cache: ProviderConfigCache + + def __init__( + self, + tenant_id: str, + config: list[BasicProviderConfig], + provider_config_cache: ProviderConfigCache, + ): + self.tenant_id = tenant_id + self.config = config + self.provider_config_cache = provider_config_cache def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: """ @@ -72,18 +97,13 @@ class ProviderConfigEncrypter(BaseModel): return data - def decrypt(self, data: dict[str, str]) -> dict[str, str]: + def decrypt(self, data: dict[str, str]) -> dict[str, Any]: """ decrypt tool credentials with tenant id return a deep copy of credentials with decrypted values """ - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cached_credentials = cache.get() + cached_credentials = self.provider_config_cache.get() if cached_credentials: return cached_credentials data = self._deep_copy(data) @@ -104,16 +124,24 @@ class ProviderConfigEncrypter(BaseModel): except Exception: pass - cache.set(data) + self.provider_config_cache.set(data) return data - def delete_tool_credentials_cache(self): - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cache.delete() + +def create_encrypter( + tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache +): + return ProviderConfigEncrypter( + tenant_id=tenant_id, config=config, provider_config_cache=cache + ), cache + + +def create_generic_encrypter( + tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str +): + cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") + encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) + return encrypt, cache class ToolParameterConfigurationManager: diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index b429851349..ff84b4318b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider @@ -297,28 +297,28 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - tool_configuration = ProviderConfigEncrypter( + encrypter, cache = create_generic_encrypter( tenant_id=tenant_id, config=list(provider_controller.get_credentials_schema()), provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = original_credentials[name] - credentials = tool_configuration.encrypt(credentials) + credentials = encrypter.encrypt(credentials) provider.credentials_str = json.dumps(credentials) db.session.add(provider) db.session.commit() # delete cache - tool_configuration.delete_tool_credentials_cache() + cache.delete() # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) @@ -416,15 +416,15 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_generic_encrypter( tenant_id=tenant_id, config=list(provider_controller.get_credentials_schema()), provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - decrypted_credentials = tool_configuration.decrypt(credentials) + decrypted_credentials = encrypter.decrypt(credentials) # check if the credential has changed, save the original credential - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials) for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = decrypted_credentials[name] diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 80ee9b080c..17c1a4b421 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -8,19 +8,18 @@ from sqlalchemy.orm import Session from configs import dify_config from core.helper.position_helper import is_filtered +from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError -from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError -from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.configuration import create_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient @@ -58,20 +57,15 @@ class BuiltinToolManageService: return result @staticmethod - def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_info(tenant_id: str, provider: str): """ get builtin tool provider info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # check if user has added the provider builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_configuration.decrypt(credentials) + if builtin_provider is None: + raise ValueError(f"you have not added provider {provider}") entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -80,7 +74,6 @@ class BuiltinToolManageService: ) entity.original_credentials = {} - return entity @staticmethod @@ -96,32 +89,34 @@ class BuiltinToolManageService: :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder(provider.get_credentials_schema(credential_type)) + return jsonable_encoder(provider.get_credentials_schema_by_type(credential_type)) @staticmethod def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict, credential_id: str, name: str | None = None + user_id: str, tenant_id: str, provider: str, credentials: dict, credential_id: str, name: str | None = None ): """ update builtin tool provider """ # get if the provider exists - provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) + db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) - if provider is None: - raise ValueError(f"you have not added provider {provider_name}") + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") try: - if ToolProviderCredentialType.of(provider.credential_type).is_editable(): - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + if ToolProviderCredentialType.of(db_provider.credential_type).is_editable(): + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") + raise ValueError(f"provider {provider} does not need credentials") - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) # Decrypt and restore original credentials for masked values - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + original_credentials = encrypter.decrypt(db_provider.credentials) + masked_credentials = encrypter.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for key, value in credentials.items(): @@ -131,13 +126,13 @@ class BuiltinToolManageService: provider_controller.validate_credentials(user_id, credentials) # encrypt credentials - encrypted_credentials = tool_configuration.encrypt(credentials) - provider.encrypted_credentials = json.dumps(encrypted_credentials) - tool_configuration.delete_tool_credentials_cache() + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) + + cache.delete() # update name if provided - if name is not None and provider.name != name: - provider.name = name + if name is not None and db_provider.name != name: + db_provider.name = name db.session.commit() except ( @@ -176,7 +171,7 @@ class BuiltinToolManageService: name if name else BuiltinToolManageService.generate_builtin_tool_provider_name( - tenant_id, provider, credential_type=api_type + tenant_id=tenant_id, provider=provider, credential_type=api_type ) ) @@ -193,20 +188,35 @@ class BuiltinToolManageService: if not provider_controller.need_credentials: raise ValueError(f"provider {provider} does not need credentials") - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) - - # Encrypt and save the credentials - BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller=provider_controller, - tool_configuration=tool_configuration, - provider=db_provider, - credentials=credentials, - user_id=user_id, + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller ) + + # encrypt credentials + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) + + cache.delete() db.session.add(db_provider) db.session.commit() return {"result": "success"} + @staticmethod + def create_tool_encrypter( + tenant_id: str, + db_provider: BuiltinToolProvider, + provider: str, + provider_controller: BuiltinToolProviderController, + ): + encrypter, cache = create_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type) + ], + cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id), + ) + return encrypter, cache + @staticmethod def generate_builtin_tool_provider_name( tenant_id: str, provider: str, credential_type: ToolProviderCredentialType @@ -273,12 +283,13 @@ class BuiltinToolManageService: default_provider.is_default = True provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, default_provider, default_provider.provider, provider_controller + ) + credentials: list[ToolProviderCredentialApiEntity] = [] for provider in providers: - decrypt_credential = tool_configuration.mask_tool_credentials( - tool_configuration.decrypt(provider.credentials) - ) + decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials)) credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( provider=provider, credentials=decrypt_credential, @@ -287,22 +298,24 @@ class BuiltinToolManageService: return credentials @staticmethod - def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str): + def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str): """ delete tool provider """ tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) if tool_provider is None: - raise ValueError(f"you have not added provider {provider_name}") + raise ValueError(f"you have not added provider {provider}") db.session.delete(tool_provider) db.session.commit() # delete cache - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) - tool_configuration.delete_tool_credentials_cache() + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + _, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, tool_provider, provider, provider_controller + ) + cache.delete() return {"result": "success"} @@ -493,57 +506,35 @@ class BuiltinToolManageService: ) @staticmethod - def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController): - return ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - - @staticmethod - def _encrypt_and_save_credentials( - provider_controller: BuiltinToolProviderController | PluginToolProviderController, - tool_configuration: ProviderConfigEncrypter, - provider: BuiltinToolProvider, - credentials: dict, - user_id: str, - ): - """ - Validate and encrypt credentials, then save to database - - :param provider_controller: the provider controller - :param tool_configuration: the tool configuration encrypter - :param provider: the provider object from database - :param credentials: the credentials to encrypt and save - :param user_id: the user id for validation - """ - if ToolProviderCredentialType.of(provider.credential_type).is_validate_allowed(): - provider_controller.validate_credentials(user_id, credentials) - - # encrypt credentials - encrypted_credentials = tool_configuration.encrypt(credentials) - provider.encrypted_credentials = json.dumps(encrypted_credentials) - tool_configuration.delete_tool_credentials_cache() - - @staticmethod - def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict): + def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict): """ setup oauth custom client """ - provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - if not provider_controller: - raise ToolProviderNotFoundError(f"Provider {provider} not found") + with Session(db.engine) as session: + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + if not isinstance(provider_controller, BuiltinToolProviderController): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") - # Validate and encrypt credentials - BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller=provider_controller, - tool_configuration=tool_configuration, - provider=None, # No need to save in DB - credentials=client_params, - user_id=user_id, - ) + encrypter, _ = create_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + # encrypt credentials + encrypted_credentials = encrypter.encrypt(client_params) + session.add( + ToolOAuthTenantClient( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + enabled=True, + encrypted_oauth_params=json.dumps(encrypted_credentials), + ) + ) + session.commit() return {"result": "success"} diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 160352c4c0..1c3ef3d48c 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,6 +5,7 @@ from typing import Optional, Union, cast from yarl import URL from configs import dify_config +from core.helper.provider_cache import ToolProviderCredentialsCache from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -19,7 +20,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.configuration import create_encrypter, create_generic_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -109,7 +110,14 @@ class ToolTransformService: result.plugin_unique_identifier = provider_controller.plugin_unique_identifier # get credentials schema - schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} + schema = { + x.to_basic_provider_config().name: x + for x in provider_controller.get_credentials_schema_by_type( + ToolProviderCredentialType.of(db_provider.credential_type) + if db_provider + else ToolProviderCredentialType.API_KEY + ) + } for name, value in schema.items(): if result.masked_credentials: @@ -126,15 +134,23 @@ class ToolTransformService: credentials = db_provider.credentials # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type( + ToolProviderCredentialType.of(db_provider.credential_type) + ) + ], + cache=ToolProviderCredentialsCache( + tenant_id=db_provider.tenant_id, + provider=db_provider.provider, + credential_id=db_provider.id, + ), ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials result.original_credentials = decrypted_credentials @@ -236,7 +252,7 @@ class ToolTransformService: if decrypt_credentials: # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_generic_encrypter( tenant_id=db_provider.tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, @@ -244,8 +260,8 @@ class ToolTransformService: ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials @@ -264,7 +280,7 @@ class ToolTransformService: # fork tool runtime tool = tool.fork_tool_runtime( runtime=ToolRuntime( - credentials= {}, + credentials={}, tenant_id=tenant_id, ) )