optimize: tool

This commit is contained in:
Yeuoly 2024-04-02 16:58:24 +08:00
parent 36c3774fac
commit e46c3a9235
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
6 changed files with 192 additions and 172 deletions

View File

@ -76,13 +76,27 @@ class ToolBuiltinProviderUpdateApi(Resource):
provider, provider,
args['credentials'], args['credentials'],
) )
class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return ToolManageService.get_builtin_tool_provider_credentials(
user_id,
tenant_id,
provider,
)
class ToolBuiltinProviderIconApi(Resource): class ToolBuiltinProviderIconApi(Resource):
@setup_required @setup_required
def get(self, provider): def get(self, provider):
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider) icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE')) icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE'))
return send_file(io.BytesIO(icon_bytes), mimetype=minetype, max_age=icon_cache_max_age) return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
class ToolModelProviderIconApi(Resource): class ToolModelProviderIconApi(Resource):
@setup_required @setup_required
@ -333,6 +347,7 @@ api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools') api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete') api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update') api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema') api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon') api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon') api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
@ -340,7 +355,7 @@ api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provid
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update')
api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete')
api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get')
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')

View File

@ -127,7 +127,8 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials :return: whether the provider needs credentials
""" """
return self.credentials_schema is not None and len(self.credentials_schema) != 0 return self.credentials_schema is not None and \
len(self.credentials_schema) != 0
@property @property
def app_type(self) -> ToolProviderType: def app_type(self) -> ToolProviderType:

View File

@ -3,33 +3,29 @@ import logging
import mimetypes import mimetypes
from collections.abc import Generator from collections.abc import Generator
from os import listdir, path from os import listdir, path
from threading import Lock
from typing import Any, Union from typing import Any, Union
from flask import current_app from flask import current_app
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.tools import *
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constant import DEFAULT_PROVIDERS
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ToolInvokeMessage,
ToolParameter, ToolParameter,
) )
from core.tools.entities.user_entities import UserToolProvider from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.model_tool_provider import ModelToolProviderController from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.tool.tool import Tool from core.tools.tool.tool import Tool
from core.tools.utils.configuration import ( from core.tools.utils.configuration import (
ModelToolConfigurationManager,
ToolConfigurationManager, ToolConfigurationManager,
ToolParameterConfigurationManager, ToolParameterConfigurationManager,
) )
@ -42,68 +38,31 @@ from services.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_builtin_providers = {}
_builtin_tools_labels = {}
class ToolManager: class ToolManager:
@staticmethod _builtin_provider_lock = Lock()
def invoke( _builtin_providers = {}
provider: str, _builtin_providers_loaded = False
tool_id: str, _builtin_tools_labels = {}
tool_name: str,
tool_parameters: dict[str, Any],
credentials: dict[str, Any],
prompt_messages: list[PromptMessage],
) -> list[ToolInvokeMessage]:
"""
invoke the assistant
:param provider: the name of the provider @classmethod
:param tool_id: the id of the tool def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
:param tool_name: the name of the tool, defined in `get_tools`
:param tool_parameters: the parameters of the tool
:param credentials: the credentials of the tool
:param prompt_messages: the prompt messages that the tool can use
:return: the messages that the tool wants to send to the user
"""
provider_entity: ToolProviderController = None
if provider == DEFAULT_PROVIDERS.API_BASED:
provider_entity = ApiBasedToolProviderController()
elif provider == DEFAULT_PROVIDERS.APP_BASED:
provider_entity = AppBasedToolProviderEntity()
if provider_entity is None:
# fetch the provider from .provider.builtin
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
parent_type=ToolProviderController)
provider_entity = provider_class()
return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
@staticmethod
def get_builtin_provider(provider: str) -> BuiltinToolProviderController:
global _builtin_providers
""" """
get the builtin provider get the builtin provider
:param provider: the name of the provider :param provider: the name of the provider
:return: the provider :return: the provider
""" """
if len(_builtin_providers) == 0: if len(cls._builtin_providers) == 0:
# init the builtin providers # init the builtin providers
ToolManager.list_builtin_providers() cls.load_builtin_providers_cache()
if provider not in _builtin_providers: if provider not in cls._builtin_providers:
raise ToolProviderNotFoundError(f'builtin provider {provider} not found') raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
return _builtin_providers[provider] return cls._builtin_providers[provider]
@staticmethod @classmethod
def get_builtin_tool(provider: str, tool_name: str) -> BuiltinTool: def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
""" """
get the builtin tool get the builtin tool
@ -112,13 +71,13 @@ class ToolManager:
:return: the provider, the tool :return: the provider, the tool
""" """
provider_controller = ToolManager.get_builtin_provider(provider) provider_controller = cls.get_builtin_provider(provider)
tool = provider_controller.get_tool(tool_name) tool = provider_controller.get_tool(tool_name)
return tool return tool
@staticmethod @classmethod
def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool]:
""" """
get the tool get the tool
@ -130,19 +89,19 @@ class ToolManager:
:return: the tool :return: the tool
""" """
if provider_type == 'builtin': if provider_type == 'builtin':
return ToolManager.get_builtin_tool(provider_id, tool_name) return cls.get_builtin_tool(provider_id, tool_name)
elif provider_type == 'api': elif provider_type == 'api':
if tenant_id is None: if tenant_id is None:
raise ValueError('tenant id is required for api provider') raise ValueError('tenant id is required for api provider')
api_provider, _ = ToolManager.get_api_provider_controller(tenant_id, provider_id) api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
return api_provider.get_tool(tool_name) return api_provider.get_tool(tool_name)
elif provider_type == 'app': elif provider_type == 'app':
raise NotImplementedError('app provider not implemented') raise NotImplementedError('app provider not implemented')
else: else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found') raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod @classmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \ def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
-> Union[BuiltinTool, ApiTool]: -> Union[BuiltinTool, ApiTool]:
""" """
get the tool runtime get the tool runtime
@ -154,10 +113,10 @@ class ToolManager:
:return: the tool :return: the tool
""" """
if provider_type == 'builtin': if provider_type == 'builtin':
builtin_tool = ToolManager.get_builtin_tool(provider_name, tool_name) builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
# check if the builtin tool need credentials # check if the builtin tool need credentials
provider_controller = ToolManager.get_builtin_provider(provider_name) provider_controller = cls.get_builtin_provider(provider_name)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(meta={ return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tenant_id, 'tenant_id': tenant_id,
@ -175,7 +134,7 @@ class ToolManager:
# decrypt the credentials # decrypt the credentials
credentials = builtin_provider.credentials credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name) controller = cls.get_builtin_provider(provider_name)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
@ -190,7 +149,7 @@ class ToolManager:
if tenant_id is None: if tenant_id is None:
raise ValueError('tenant id is required for api provider') raise ValueError('tenant id is required for api provider')
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name) api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials # decrypt the credentials
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
@ -204,7 +163,7 @@ class ToolManager:
if tenant_id is None: if tenant_id is None:
raise ValueError('tenant id is required for model provider') raise ValueError('tenant id is required for model provider')
# get model provider # get model provider
model_provider = ToolManager.get_model_provider(tenant_id, provider_name) model_provider = cls.get_model_provider(tenant_id, provider_name)
# get tool # get tool
model_tool = model_provider.get_tool(tool_name) model_tool = model_provider.get_tool(tool_name)
@ -218,8 +177,8 @@ class ToolManager:
else: else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found') raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod @classmethod
def _init_runtime_parameter(parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
""" """
init runtime parameter init runtime parameter
""" """
@ -262,12 +221,12 @@ class ToolManager:
return parameter_value return parameter_value
@staticmethod @classmethod
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity) -> Tool: def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
""" """
get the agent tool runtime get the agent tool runtime
""" """
tool_entity = ToolManager.get_tool_runtime( tool_entity = cls.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
tool_name=agent_tool.tool_name, tool_name=agent_tool.tool_name,
tenant_id=tenant_id, tenant_id=tenant_id,
@ -277,7 +236,7 @@ class ToolManager:
for parameter in parameters: for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM: if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory # save tool parameter to tool entity memory
value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters) value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
runtime_parameters[parameter.name] = value runtime_parameters[parameter.name] = value
# decrypt runtime parameters # decrypt runtime parameters
@ -292,12 +251,12 @@ class ToolManager:
tool_entity.runtime.runtime_parameters.update(runtime_parameters) tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity return tool_entity
@staticmethod @classmethod
def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity): def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
""" """
get the workflow tool runtime get the workflow tool runtime
""" """
tool_entity = ToolManager.get_tool_runtime( tool_entity = cls.get_tool_runtime(
provider_type=workflow_tool.provider_type, provider_type=workflow_tool.provider_type,
provider_name=workflow_tool.provider_id, provider_name=workflow_tool.provider_id,
tool_name=workflow_tool.tool_name, tool_name=workflow_tool.tool_name,
@ -309,7 +268,7 @@ class ToolManager:
for parameter in parameters: for parameter in parameters:
# save tool parameter to tool entity memory # save tool parameter to tool entity memory
if parameter.form == ToolParameter.ToolParameterForm.FORM: if parameter.form == ToolParameter.ToolParameterForm.FORM:
value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
runtime_parameters[parameter.name] = value runtime_parameters[parameter.name] = value
# decrypt runtime parameters # decrypt runtime parameters
@ -326,8 +285,8 @@ class ToolManager:
tool_entity.runtime.runtime_parameters.update(runtime_parameters) tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity return tool_entity
@staticmethod @classmethod
def get_builtin_provider_icon(provider: str) -> tuple[str, str]: def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
""" """
get the absolute path of the icon of the builtin provider get the absolute path of the icon of the builtin provider
@ -336,7 +295,7 @@ class ToolManager:
:return: the absolute path of the icon, the mime type of the icon :return: the absolute path of the icon, the mime type of the icon
""" """
# get provider # get provider
provider_controller = ToolManager.get_builtin_provider(provider) provider_controller = cls.get_builtin_provider(provider)
absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets',
provider_controller.identity.icon) provider_controller.identity.icon)
@ -350,15 +309,25 @@ class ToolManager:
return absolute_path, mime_type return absolute_path, mime_type
@staticmethod @classmethod
def list_builtin_providers() -> Generator[BuiltinToolProviderController, None, None]: def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
global _builtin_providers
# use cache first # use cache first
if len(_builtin_providers) > 0: if cls._builtin_providers_loaded:
yield from list(_builtin_providers.values()) yield from list(cls._builtin_providers.values())
return return
with cls._builtin_provider_lock:
if cls._builtin_providers_loaded:
yield from list(cls._builtin_providers.values())
return
yield from cls._list_builtin_providers()
@classmethod
def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
"""
list all the builtin providers
"""
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'): if provider.startswith('__'):
continue continue
@ -375,52 +344,54 @@ class ToolManager:
'provider', 'builtin', provider, f'{provider}.py'), 'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController) parent_type=BuiltinToolProviderController)
provider: BuiltinToolProviderController = provider_class() provider: BuiltinToolProviderController = provider_class()
_builtin_providers[provider.identity.name] = provider cls._builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools(): for tool in provider.get_tools():
_builtin_tools_labels[tool.identity.name] = tool.identity.label cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
yield provider yield provider
except Exception as e: except Exception as e:
logger.error(f'load builtin provider {provider} error: {e}') logger.error(f'load builtin provider {provider} error: {e}')
continue continue
# set builtin providers loaded
cls._builtin_providers_loaded = True
@staticmethod @classmethod
def load_builtin_providers_cache(): def load_builtin_providers_cache(cls):
for _ in ToolManager.list_builtin_providers(): for _ in cls.list_builtin_providers():
pass pass
@staticmethod @classmethod
def clear_builtin_providers_cache(): def clear_builtin_providers_cache(cls):
global _builtin_providers cls._builtin_providers = {}
_builtin_providers = {} cls._builtin_providers_loaded = False
@staticmethod # @classmethod
def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]: # def list_model_providers(cls, tenant_id: str = None) -> list[ModelToolProviderController]:
""" # """
list all the model providers # list all the model providers
:return: the list of the model providers # :return: the list of the model providers
""" # """
tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' # tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
# get configurations # # get configurations
model_configurations = ModelToolConfigurationManager.get_all_configuration() # model_configurations = ModelToolConfigurationManager.get_all_configuration()
# get all providers # # get all providers
provider_manager = ProviderManager() # provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id).values() # configurations = provider_manager.get_configurations(tenant_id).values()
# get model providers # # get model providers
model_providers: list[ModelToolProviderController] = [] # model_providers: list[ModelToolProviderController] = []
for configuration in configurations: # for configuration in configurations:
# all the model tool should be configurated # # all the model tool should be configurated
if configuration.provider.provider not in model_configurations: # if configuration.provider.provider not in model_configurations:
continue # continue
if not ModelToolProviderController.is_configuration_valid(configuration): # if not ModelToolProviderController.is_configuration_valid(configuration):
continue # continue
model_providers.append(ModelToolProviderController.from_db(configuration)) # model_providers.append(ModelToolProviderController.from_db(configuration))
return model_providers # return model_providers
@staticmethod @classmethod
def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController: def get_model_provider(cls, tenant_id: str, provider_name: str) -> ModelToolProviderController:
""" """
get the model provider get the model provider
@ -437,8 +408,8 @@ class ToolManager:
return ModelToolProviderController.from_db(configuration) return ModelToolProviderController.from_db(configuration)
@staticmethod @classmethod
def get_tool_label(tool_name: str) -> Union[I18nObject, None]: def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
""" """
get the tool label get the tool label
@ -446,44 +417,44 @@ class ToolManager:
:return: the label of the tool :return: the label of the tool
""" """
global _builtin_tools_labels cls._builtin_tools_labels
if len(_builtin_tools_labels) == 0: if len(cls._builtin_tools_labels) == 0:
# init the builtin providers # init the builtin providers
ToolManager.load_builtin_providers_cache() cls.load_builtin_providers_cache()
if tool_name not in _builtin_tools_labels: if tool_name not in cls._builtin_tools_labels:
return None return None
return _builtin_tools_labels[tool_name] return cls._builtin_tools_labels[tool_name]
@staticmethod @classmethod
def user_list_providers( def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
user_id: str,
tenant_id: str,
) -> list[UserToolProvider]:
result_providers: dict[str, UserToolProvider] = {} result_providers: dict[str, UserToolProvider] = {}
# get builtin providers # get builtin providers
builtin_providers = ToolManager.list_builtin_providers() builtin_providers = cls.list_builtin_providers()
# get db builtin providers # get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
filter(BuiltinToolProvider.tenant_id == tenant_id).all() filter(BuiltinToolProvider.tenant_id == tenant_id).all()
find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider), find_db_builtin_provider = lambda provider: next(
None) (x for x in db_builtin_providers if x.provider == provider),
None
)
# append builtin providers # append builtin providers
for provider in builtin_providers: for provider in builtin_providers:
user_provider = ToolTransformService.builtin_provider_to_user_provider( user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider, provider_controller=provider,
db_provider=find_db_builtin_provider(provider.identity.name), db_provider=find_db_builtin_provider(provider.identity.name),
decrypt_credentials=False
) )
result_providers[provider.identity.name] = user_provider result_providers[provider.identity.name] = user_provider
# # get model tool providers # # get model tool providers
# model_providers = ToolManager.list_model_providers(tenant_id=tenant_id) # model_providers = cls.list_model_providers(tenant_id=tenant_id)
# # append model providers # # append model providers
# for provider in model_providers: # for provider in model_providers:
# user_provider = ToolTransformService.model_provider_to_user_provider( # user_provider = ToolTransformService.model_provider_to_user_provider(
@ -502,13 +473,14 @@ class ToolManager:
user_provider = ToolTransformService.api_provider_to_user_provider( user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller=provider_controller, provider_controller=provider_controller,
db_provider=db_api_provider, db_provider=db_api_provider,
decrypt_credentials=False
) )
result_providers[db_api_provider.name] = user_provider result_providers[db_api_provider.name] = user_provider
return BuiltinToolProviderSort.sort(list(result_providers.values())) return BuiltinToolProviderSort.sort(list(result_providers.values()))
@staticmethod @classmethod
def get_api_provider_controller(tenant_id: str, provider_id: str) -> tuple[ def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
ApiBasedToolProviderController, dict[str, Any]]: ApiBasedToolProviderController, dict[str, Any]]:
""" """
get the api provider get the api provider
@ -527,14 +499,15 @@ class ToolManager:
controller = ApiBasedToolProviderController.from_db( controller = ApiBasedToolProviderController.from_db(
provider, provider,
ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
) )
controller.load_bundled_tools(provider.tools) controller.load_bundled_tools(provider.tools)
return controller, provider.credentials return controller, provider.credentials
@staticmethod @classmethod
def user_get_api_provider(provider: str, tenant_id: str) -> dict: def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
""" """
get api provider get api provider
""" """
@ -582,8 +555,8 @@ class ToolManager:
'privacy_policy': provider.privacy_policy 'privacy_policy': provider.privacy_policy
})) }))
@staticmethod @classmethod
def get_tool_icon(tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
""" """
get the tool icon get the tool icon
@ -613,3 +586,5 @@ class ToolManager:
} }
else: else:
raise ValueError(f"provider type {provider_type} not found") raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_builtin_providers_cache()

View File

@ -1,4 +1,5 @@
import os import os
from copy import deepcopy
from typing import Any, Union from typing import Any, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -25,7 +26,7 @@ class ToolConfigurationManager(BaseModel):
""" """
deep copy credentials deep copy credentials
""" """
return {key: value for key, value in credentials.items()} return deepcopy(credentials)
def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]: def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
""" """

View File

@ -354,6 +354,27 @@ class ToolManageService:
return { 'result': 'success' } return { 'result': 'success' }
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
):
"""
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
if provider is None:
raise ValueError(f'you have not added provider {provider}')
provider_controller = ToolManager.get_builtin_provider(provider.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials)
return credentials
@staticmethod @staticmethod
def update_api_tool_provider( def update_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
@ -631,7 +652,8 @@ class ToolManageService:
# convert provider controller to user provider # convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller, provider_controller=provider_controller,
db_provider=find_provider(provider_controller.identity.name) db_provider=find_provider(provider_controller.identity.name),
decrypt_credentials=True
) )
# add icon # add icon
@ -668,7 +690,8 @@ class ToolManageService:
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
user_provider = ToolTransformService.api_provider_to_user_provider( user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller, provider_controller,
db_provider=provider db_provider=provider,
decrypt_credentials=True
) )
# add icon # add icon

View File

@ -64,6 +64,7 @@ class ToolTransformService:
def builtin_provider_to_user_provider( def builtin_provider_to_user_provider(
provider_controller: BuiltinToolProviderController, provider_controller: BuiltinToolProviderController,
db_provider: Optional[BuiltinToolProvider], db_provider: Optional[BuiltinToolProvider],
decrypt_credentials: bool = True
) -> UserToolProvider: ) -> UserToolProvider:
""" """
convert provider controller to user provider convert provider controller to user provider
@ -100,19 +101,20 @@ class ToolTransformService:
elif db_provider: elif db_provider:
result.is_team_authorization = True result.is_team_authorization = True
credentials = db_provider.credentials if decrypt_credentials:
credentials = db_provider.credentials
# init tool configuration # init tool configuration
tool_configuration = ToolConfigurationManager( tool_configuration = ToolConfigurationManager(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
provider_controller=provider_controller provider_controller=provider_controller
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
result.masked_credentials = masked_credentials result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials result.original_credentials = decrypted_credentials
return result return result
@ -126,7 +128,8 @@ class ToolTransformService:
# package tool provider controller # package tool provider controller
controller = ApiBasedToolProviderController.from_db( controller = ApiBasedToolProviderController.from_db(
db_provider=db_provider, db_provider=db_provider,
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
) )
return controller return controller
@ -135,6 +138,7 @@ class ToolTransformService:
def api_provider_to_user_provider( def api_provider_to_user_provider(
provider_controller: ApiBasedToolProviderController, provider_controller: ApiBasedToolProviderController,
db_provider: ApiToolProvider, db_provider: ApiToolProvider,
decrypt_credentials: bool = True
) -> UserToolProvider: ) -> UserToolProvider:
""" """
convert provider controller to user provider convert provider controller to user provider
@ -165,17 +169,18 @@ class ToolTransformService:
tools=[] tools=[]
) )
# init tool configuration if decrypt_credentials:
tool_configuration = ToolConfigurationManager( # init tool configuration
tenant_id=db_provider.tenant_id, tool_configuration = ToolConfigurationManager(
provider_controller=provider_controller tenant_id=db_provider.tenant_id,
) provider_controller=provider_controller
)
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
result.masked_credentials = masked_credentials result.masked_credentials = masked_credentials
return result return result