diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 5d499c0d07..b02008339e 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -76,13 +76,27 @@ class ToolBuiltinProviderUpdateApi(Resource): provider, args['credentials'], ) + +class ToolBuiltinProviderGetCredentialsApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + return ToolManageService.get_builtin_tool_provider_credentials( + user_id, + tenant_id, + provider, + ) class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): - icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider) + icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider) icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE')) - return send_file(io.BytesIO(icon_bytes), mimetype=minetype, max_age=icon_cache_max_age) + return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) class ToolModelProviderIconApi(Resource): @setup_required @@ -333,6 +347,7 @@ api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin//update') +api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin//credentials') api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin//credentials_schema') api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin//icon') api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model//icon') @@ -340,7 +355,7 @@ api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provid api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') -api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') +api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index f72a757bc8..c2178cdd40 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -127,7 +127,8 @@ class BuiltinToolProviderController(ToolProviderController): :return: whether the provider needs credentials """ - return self.credentials_schema is not None and len(self.credentials_schema) != 0 + return self.credentials_schema is not None and \ + len(self.credentials_schema) != 0 @property def app_type(self) -> ToolProviderType: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index c5831876c8..f3fa18393e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -3,33 +3,29 @@ import logging import mimetypes from collections.abc import Generator from os import listdir, path +from threading import Lock from typing import Any, Union from flask import current_app from core.agent.entities import AgentToolEntity -from core.model_runtime.entities.message_entities import PromptMessage from core.provider_manager import ProviderManager +from core.tools import * from core.tools.entities.common_entities import I18nObject -from core.tools.entities.constant import DEFAULT_PROVIDERS from core.tools.entities.tool_entities import ( ApiProviderAuthType, - ToolInvokeMessage, ToolParameter, ) from core.tools.entities.user_entities import UserToolProvider from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiBasedToolProviderController -from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.model_tool_provider import ModelToolProviderController -from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.utils.configuration import ( - ModelToolConfigurationManager, ToolConfigurationManager, ToolParameterConfigurationManager, ) @@ -42,68 +38,31 @@ from services.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) -_builtin_providers = {} -_builtin_tools_labels = {} - - class ToolManager: - @staticmethod - def invoke( - provider: str, - tool_id: str, - tool_name: str, - tool_parameters: dict[str, Any], - credentials: dict[str, Any], - prompt_messages: list[PromptMessage], - ) -> list[ToolInvokeMessage]: - """ - invoke the assistant + _builtin_provider_lock = Lock() + _builtin_providers = {} + _builtin_providers_loaded = False + _builtin_tools_labels = {} - :param provider: the name of the provider - :param tool_id: the id of the tool - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool - :param credentials: the credentials of the tool - :param prompt_messages: the prompt messages that the tool can use - - :return: the messages that the tool wants to send to the user - """ - provider_entity: ToolProviderController = None - if provider == DEFAULT_PROVIDERS.API_BASED: - provider_entity = ApiBasedToolProviderController() - elif provider == DEFAULT_PROVIDERS.APP_BASED: - provider_entity = AppBasedToolProviderEntity() - - if provider_entity is None: - # fetch the provider from .provider.builtin - provider_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.{provider}', - script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'), - parent_type=ToolProviderController) - provider_entity = provider_class() - - return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages) - - @staticmethod - def get_builtin_provider(provider: str) -> BuiltinToolProviderController: - global _builtin_providers + @classmethod + def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: """ get the builtin provider :param provider: the name of the provider :return: the provider """ - if len(_builtin_providers) == 0: + if len(cls._builtin_providers) == 0: # init the builtin providers - ToolManager.list_builtin_providers() + cls.load_builtin_providers_cache() - if provider not in _builtin_providers: + if provider not in cls._builtin_providers: raise ToolProviderNotFoundError(f'builtin provider {provider} not found') - return _builtin_providers[provider] + return cls._builtin_providers[provider] - @staticmethod - def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool: + @classmethod + def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: """ get the builtin tool @@ -112,13 +71,13 @@ class ToolManager: :return: the provider, the tool """ - provider_controller = ToolManager.get_builtin_provider(provider) + provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) return tool - @staticmethod - def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ + @classmethod + def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ -> Union[BuiltinTool, ApiTool]: """ get the tool @@ -130,19 +89,19 @@ class ToolManager: :return: the tool """ if provider_type == 'builtin': - return ToolManager.get_builtin_tool(provider_id, tool_name) + return cls.get_builtin_tool(provider_id, tool_name) elif provider_type == 'api': if tenant_id is None: raise ValueError('tenant id is required for api provider') - api_provider, _ = ToolManager.get_api_provider_controller(tenant_id, provider_id) + api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id) return api_provider.get_tool(tool_name) elif provider_type == 'app': raise NotImplementedError('app provider not implemented') else: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') - @staticmethod - def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \ + @classmethod + def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \ -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -154,10 +113,10 @@ class ToolManager: :return: the tool """ if provider_type == 'builtin': - builtin_tool = ToolManager.get_builtin_tool(provider_name, tool_name) + builtin_tool = cls.get_builtin_tool(provider_name, tool_name) # check if the builtin tool need credentials - provider_controller = ToolManager.get_builtin_provider(provider_name) + provider_controller = cls.get_builtin_provider(provider_name) if not provider_controller.need_credentials: return builtin_tool.fork_tool_runtime(meta={ 'tenant_id': tenant_id, @@ -175,7 +134,7 @@ class ToolManager: # decrypt the credentials credentials = builtin_provider.credentials - controller = ToolManager.get_builtin_provider(provider_name) + controller = cls.get_builtin_provider(provider_name) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) @@ -190,7 +149,7 @@ class ToolManager: if tenant_id is None: raise ValueError('tenant id is required for api provider') - api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name) + api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name) # decrypt the credentials tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) @@ -204,7 +163,7 @@ class ToolManager: if tenant_id is None: raise ValueError('tenant id is required for model provider') # get model provider - model_provider = ToolManager.get_model_provider(tenant_id, provider_name) + model_provider = cls.get_model_provider(tenant_id, provider_name) # get tool model_tool = model_provider.get_tool(tool_name) @@ -218,8 +177,8 @@ class ToolManager: else: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') - @staticmethod - def _init_runtime_parameter(parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: + @classmethod + def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: """ init runtime parameter """ @@ -262,12 +221,12 @@ class ToolManager: return parameter_value - @staticmethod - def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity) -> Tool: + @classmethod + def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool: """ get the agent tool runtime """ - tool_entity = ToolManager.get_tool_runtime( + tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, @@ -277,7 +236,7 @@ class ToolManager: for parameter in parameters: if parameter.form == ToolParameter.ToolParameterForm.FORM: # save tool parameter to tool entity memory - value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters) + value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters) runtime_parameters[parameter.name] = value # decrypt runtime parameters @@ -292,12 +251,12 @@ class ToolManager: tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity - @staticmethod - def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity): + @classmethod + def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity): """ get the workflow tool runtime """ - tool_entity = ToolManager.get_tool_runtime( + tool_entity = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, provider_name=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, @@ -309,7 +268,7 @@ class ToolManager: for parameter in parameters: # save tool parameter to tool entity memory if parameter.form == ToolParameter.ToolParameterForm.FORM: - value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) + value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations) runtime_parameters[parameter.name] = value # decrypt runtime parameters @@ -326,8 +285,8 @@ class ToolManager: tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity - @staticmethod - def get_builtin_provider_icon(provider: str) -> tuple[str, str]: + @classmethod + def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ get the absolute path of the icon of the builtin provider @@ -336,7 +295,7 @@ class ToolManager: :return: the absolute path of the icon, the mime type of the icon """ # get provider - provider_controller = ToolManager.get_builtin_provider(provider) + provider_controller = cls.get_builtin_provider(provider) absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', provider_controller.identity.icon) @@ -350,15 +309,25 @@ class ToolManager: return absolute_path, mime_type - @staticmethod - def list_builtin_providers() -> Generator[BuiltinToolProviderController, None, None]: - global _builtin_providers - + @classmethod + def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: # use cache first - if len(_builtin_providers) > 0: - yield from list(_builtin_providers.values()) + if cls._builtin_providers_loaded: + yield from list(cls._builtin_providers.values()) return - + + with cls._builtin_provider_lock: + if cls._builtin_providers_loaded: + yield from list(cls._builtin_providers.values()) + return + + yield from cls._list_builtin_providers() + + @classmethod + def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: + """ + list all the builtin providers + """ for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): if provider.startswith('__'): continue @@ -375,52 +344,54 @@ class ToolManager: 'provider', 'builtin', provider, f'{provider}.py'), parent_type=BuiltinToolProviderController) provider: BuiltinToolProviderController = provider_class() - _builtin_providers[provider.identity.name] = provider + cls._builtin_providers[provider.identity.name] = provider for tool in provider.get_tools(): - _builtin_tools_labels[tool.identity.name] = tool.identity.label + cls._builtin_tools_labels[tool.identity.name] = tool.identity.label yield provider except Exception as e: logger.error(f'load builtin provider {provider} error: {e}') continue + # set builtin providers loaded + cls._builtin_providers_loaded = True - @staticmethod - def load_builtin_providers_cache(): - for _ in ToolManager.list_builtin_providers(): + @classmethod + def load_builtin_providers_cache(cls): + for _ in cls.list_builtin_providers(): pass - @staticmethod - def clear_builtin_providers_cache(): - global _builtin_providers - _builtin_providers = {} + @classmethod + def clear_builtin_providers_cache(cls): + cls._builtin_providers = {} + cls._builtin_providers_loaded = False - @staticmethod - def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]: - """ - list all the model providers + # @classmethod + # def list_model_providers(cls, tenant_id: str = None) -> list[ModelToolProviderController]: + # """ + # list all the model providers - :return: the list of the model providers - """ - tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' - # get configurations - model_configurations = ModelToolConfigurationManager.get_all_configuration() - # get all providers - provider_manager = ProviderManager() - configurations = provider_manager.get_configurations(tenant_id).values() - # get model providers - model_providers: list[ModelToolProviderController] = [] - for configuration in configurations: - # all the model tool should be configurated - if configuration.provider.provider not in model_configurations: - continue - if not ModelToolProviderController.is_configuration_valid(configuration): - continue - model_providers.append(ModelToolProviderController.from_db(configuration)) + # :return: the list of the model providers + # """ + # tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' + # # get configurations + # model_configurations = ModelToolConfigurationManager.get_all_configuration() + # # get all providers + # provider_manager = ProviderManager() + # configurations = provider_manager.get_configurations(tenant_id).values() + # # get model providers + # model_providers: list[ModelToolProviderController] = [] + # for configuration in configurations: + # # all the model tool should be configurated + # if configuration.provider.provider not in model_configurations: + # continue + # if not ModelToolProviderController.is_configuration_valid(configuration): + # continue + # model_providers.append(ModelToolProviderController.from_db(configuration)) - return model_providers + # return model_providers - @staticmethod - def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController: + @classmethod + def get_model_provider(cls, tenant_id: str, provider_name: str) -> ModelToolProviderController: """ get the model provider @@ -437,8 +408,8 @@ class ToolManager: return ModelToolProviderController.from_db(configuration) - @staticmethod - def get_tool_label(tool_name: str) -> Union[I18nObject, None]: + @classmethod + def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ get the tool label @@ -446,44 +417,44 @@ class ToolManager: :return: the label of the tool """ - global _builtin_tools_labels - if len(_builtin_tools_labels) == 0: + cls._builtin_tools_labels + if len(cls._builtin_tools_labels) == 0: # init the builtin providers - ToolManager.load_builtin_providers_cache() + cls.load_builtin_providers_cache() - if tool_name not in _builtin_tools_labels: + if tool_name not in cls._builtin_tools_labels: return None - return _builtin_tools_labels[tool_name] + return cls._builtin_tools_labels[tool_name] - @staticmethod - def user_list_providers( - user_id: str, - tenant_id: str, - ) -> list[UserToolProvider]: + @classmethod + def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} # get builtin providers - builtin_providers = ToolManager.list_builtin_providers() - + builtin_providers = cls.list_builtin_providers() + # get db builtin providers db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ filter(BuiltinToolProvider.tenant_id == tenant_id).all() - find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider), - None) + find_db_builtin_provider = lambda provider: next( + (x for x in db_builtin_providers if x.provider == provider), + None + ) # append builtin providers for provider in builtin_providers: user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), + decrypt_credentials=False ) result_providers[provider.identity.name] = user_provider # # get model tool providers - # model_providers = ToolManager.list_model_providers(tenant_id=tenant_id) + # model_providers = cls.list_model_providers(tenant_id=tenant_id) # # append model providers # for provider in model_providers: # user_provider = ToolTransformService.model_provider_to_user_provider( @@ -502,13 +473,14 @@ class ToolManager: user_provider = ToolTransformService.api_provider_to_user_provider( provider_controller=provider_controller, db_provider=db_api_provider, + decrypt_credentials=False ) result_providers[db_api_provider.name] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) - @staticmethod - def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ + @classmethod + def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ ApiBasedToolProviderController, dict[str, Any]]: """ get the api provider @@ -527,14 +499,15 @@ class ToolManager: controller = ApiBasedToolProviderController.from_db( provider, - ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else + ApiProviderAuthType.NONE ) controller.load_bundled_tools(provider.tools) return controller, provider.credentials - @staticmethod - def user_get_api_provider(provider: str, tenant_id: str) -> dict: + @classmethod + def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: """ get api provider """ @@ -582,8 +555,8 @@ class ToolManager: 'privacy_policy': provider.privacy_policy })) - @staticmethod - def get_tool_icon(tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: + @classmethod + def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: """ get the tool icon @@ -613,3 +586,5 @@ class ToolManager: } else: raise ValueError(f"provider type {provider_type} not found") + +ToolManager.load_builtin_providers_cache() \ No newline at end of file diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 927af1f5be..619e7ffd61 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,4 +1,5 @@ import os +from copy import deepcopy from typing import Any, Union from pydantic import BaseModel @@ -25,7 +26,7 @@ class ToolConfigurationManager(BaseModel): """ deep copy credentials """ - return {key: value for key, value in credentials.items()} + return deepcopy(credentials) def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]: """ diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index aed463e97f..29245f1f3b 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -354,6 +354,27 @@ class ToolManageService: return { 'result': 'success' } + @staticmethod + def get_builtin_tool_provider_credentials( + user_id: str, tenant_id: str, provider: str + ): + """ + get builtin tool provider credentials + """ + provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ).first() + + if provider is None: + raise ValueError(f'you have not added provider {provider}') + + provider_controller = ToolManager.get_builtin_provider(provider.provider) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) + credentials = tool_configuration.mask_tool_credentials(credentials) + return credentials + @staticmethod def update_api_tool_provider( user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, @@ -631,7 +652,8 @@ class ToolManageService: # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, - db_provider=find_provider(provider_controller.identity.name) + db_provider=find_provider(provider_controller.identity.name), + decrypt_credentials=True ) # add icon @@ -668,7 +690,8 @@ class ToolManageService: provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) user_provider = ToolTransformService.api_provider_to_user_provider( provider_controller, - db_provider=provider + db_provider=provider, + decrypt_credentials=True ) # add icon diff --git a/api/services/tools_transform_service.py b/api/services/tools_transform_service.py index 861eab73c5..d54b1e0437 100644 --- a/api/services/tools_transform_service.py +++ b/api/services/tools_transform_service.py @@ -64,6 +64,7 @@ class ToolTransformService: def builtin_provider_to_user_provider( provider_controller: BuiltinToolProviderController, db_provider: Optional[BuiltinToolProvider], + decrypt_credentials: bool = True ) -> UserToolProvider: """ convert provider controller to user provider @@ -100,19 +101,20 @@ class ToolTransformService: elif db_provider: result.is_team_authorization = True - credentials = db_provider.credentials + if decrypt_credentials: + credentials = db_provider.credentials - # init tool configuration - tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller - ) - # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) + # init tool configuration + tool_configuration = ToolConfigurationManager( + tenant_id=db_provider.tenant_id, + provider_controller=provider_controller + ) + # decrypt the credentials and mask the credentials + decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) + masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) - result.masked_credentials = masked_credentials - result.original_credentials = decrypted_credentials + result.masked_credentials = masked_credentials + result.original_credentials = decrypted_credentials return result @@ -126,7 +128,8 @@ class ToolTransformService: # package tool provider controller controller = ApiBasedToolProviderController.from_db( db_provider=db_provider, - auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else + ApiProviderAuthType.NONE ) return controller @@ -135,6 +138,7 @@ class ToolTransformService: def api_provider_to_user_provider( provider_controller: ApiBasedToolProviderController, db_provider: ApiToolProvider, + decrypt_credentials: bool = True ) -> UserToolProvider: """ convert provider controller to user provider @@ -165,17 +169,18 @@ class ToolTransformService: tools=[] ) - # init tool configuration - tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller - ) + if decrypt_credentials: + # init tool configuration + tool_configuration = ToolConfigurationManager( + tenant_id=db_provider.tenant_id, + provider_controller=provider_controller + ) - # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) + # decrypt the credentials and mask the credentials + decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) + masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) - result.masked_credentials = masked_credentials + result.masked_credentials = masked_credentials return result