From 260fef40c475f447e49f3fdec92f9e165ae20519 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 21 Mar 2024 15:38:53 +0800 Subject: [PATCH] enhance: full tools --- .../console/workspace/tool_providers.py | 42 +++- api/core/tools/entities/user_entities.py | 25 ++- api/core/tools/tool_manager.py | 126 ++--------- api/services/tools_manage_service.py | 155 +++++++++----- api/services/tools_transform_service.py | 199 ++++++++++++++++++ 5 files changed, 369 insertions(+), 178 deletions(-) create mode 100644 api/services/tools_transform_service.py diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 931979c7f3..a44105537c 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -8,6 +8,7 @@ from werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from services.tools_manage_service import ToolManageService @@ -30,11 +31,11 @@ class ToolBuiltinProviderListToolsApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return ToolManageService.list_builtin_tool_provider_tools( + return jsonable_encoder(ToolManageService.list_builtin_tool_provider_tools( user_id, tenant_id, provider, - ) + )) class ToolBuiltinProviderDeleteApi(Resource): @setup_required @@ -101,11 +102,11 @@ class ToolModelProviderListToolsApi(Resource): args = parser.parse_args() - return ToolManageService.list_model_tool_provider_tools( + return jsonable_encoder(ToolManageService.list_model_tool_provider_tools( user_id, tenant_id, args['provider'], - ) + )) class ToolApiProviderAddApi(Resource): @setup_required @@ -170,11 +171,11 @@ class ToolApiProviderListToolsApi(Resource): args = parser.parse_args() - return ToolManageService.list_api_tool_provider_tools( + return jsonable_encoder(ToolManageService.list_api_tool_provider_tools( user_id, tenant_id, args['provider'], - ) + )) class ToolApiProviderUpdateApi(Resource): @setup_required @@ -301,6 +302,32 @@ class ToolApiProviderPreviousTestApi(Resource): args['schema'], ) +class ToolBuiltinListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_builtin_tools( + user_id, + tenant_id, + )]) + +class ToolApiListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_api_tools( + user_id, + tenant_id, + )]) + api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') @@ -317,3 +344,6 @@ api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/ap api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') + +api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin') +api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') \ No newline at end of file diff --git a/api/core/tools/entities/user_entities.py b/api/core/tools/entities/user_entities.py index 8a5589da27..171bf831e2 100644 --- a/api/core/tools/entities/user_entities.py +++ b/api/core/tools/entities/user_entities.py @@ -8,6 +8,13 @@ from core.tools.entities.tool_entities import ToolProviderCredentials from core.tools.tool.tool import ToolParameter +class UserTool(BaseModel): + author: str + name: str # identifier + label: I18nObject # label + description: I18nObject + parameters: Optional[list[ToolParameter]] + class UserToolProvider(BaseModel): class ProviderType(Enum): BUILTIN = "builtin" @@ -22,9 +29,11 @@ class UserToolProvider(BaseModel): icon: str label: I18nObject # label type: ProviderType - team_credentials: dict = None + masked_credentials: dict = None + original_credentials: dict = None is_team_authorization: bool = False allow_delete: bool = True + tools: list[UserTool] = None def to_dict(self) -> dict: return { @@ -35,17 +44,11 @@ class UserToolProvider(BaseModel): 'icon': self.icon, 'label': self.label.to_dict(), 'type': self.type.value, - 'team_credentials': self.team_credentials, + 'team_credentials': self.masked_credentials, 'is_team_authorization': self.is_team_authorization, - 'allow_delete': self.allow_delete + 'allow_delete': self.allow_delete, + 'tools': self.tools } class UserToolProviderCredentials(BaseModel): - credentials: dict[str, ToolProviderCredentials] - -class UserTool(BaseModel): - author: str - name: str # identifier - label: I18nObject # label - description: I18nObject - parameters: Optional[list[ToolParameter]] \ No newline at end of file + credentials: dict[str, ToolProviderCredentials] \ No newline at end of file diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d8b570fc30..632707815f 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -15,7 +15,6 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolInvokeMessage, ToolParameter, - ToolProviderCredentials, ) from core.tools.entities.user_entities import UserToolProvider from core.tools.errors import ToolProviderNotFoundError @@ -37,6 +36,7 @@ from core.tools.utils.encoder import serialize_base_model_dict from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider +from services.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -468,131 +468,47 @@ class ToolManager: tenant_id: str, ) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} + # get builtin providers builtin_providers = ToolManager.list_builtin_providers() - # append builtin providers - for provider in builtin_providers: - result_providers[provider.identity.name] = UserToolProvider( - id=provider.identity.name, - author=provider.identity.author, - name=provider.identity.name, - description=I18nObject( - en_US=provider.identity.description.en_US, - zh_Hans=provider.identity.description.zh_Hans, - ), - icon=provider.identity.icon, - label=I18nObject( - en_US=provider.identity.label.en_US, - zh_Hans=provider.identity.label.zh_Hans, - ), - type=UserToolProvider.ProviderType.BUILTIN, - team_credentials={}, - is_team_authorization=False, - ) - - # get credentials schema - schema = provider.get_credentials_schema() - for name, value in schema.items(): - result_providers[provider.identity.name].team_credentials[name] = \ - ToolProviderCredentials.CredentialsType.default(value.type) - - # check if the provider need credentials - if not provider.need_credentials: - result_providers[provider.identity.name].is_team_authorization = True - result_providers[provider.identity.name].allow_delete = False - + # get db builtin providers db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ filter(BuiltinToolProvider.tenant_id == tenant_id).all() - for db_builtin_provider in db_builtin_providers: - # add provider into providers - credentials = db_builtin_provider.credentials - provider_name = db_builtin_provider.provider - if provider_name not in result_providers: - # the provider has been deleted - continue + find_db_builtin_provider = lambda provider: next((x for x in db_builtin_providers if x.provider == provider), None) + + # append builtin providers + for provider in builtin_providers: + user_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider, + db_provider=find_db_builtin_provider(provider.identity.name), + ) - result_providers[provider_name].is_team_authorization = True - - # package builtin tool provider controller - controller = ToolManager.get_builtin_provider(provider_name) - - # init tool configuration - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) - # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) - - result_providers[provider_name].team_credentials = masked_credentials + result_providers[provider.identity.name] = user_provider # get model tool providers model_providers = ToolManager.list_model_providers(tenant_id=tenant_id) # append model providers for provider in model_providers: - result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider( - id=provider.identity.name, - author=provider.identity.author, - name=provider.identity.name, - description=I18nObject( - en_US=provider.identity.description.en_US, - zh_Hans=provider.identity.description.zh_Hans, - ), - icon=provider.identity.icon, - label=I18nObject( - en_US=provider.identity.label.en_US, - zh_Hans=provider.identity.label.zh_Hans, - ), - type=UserToolProvider.ProviderType.MODEL, - team_credentials={}, - is_team_authorization=provider.is_active, + user_provider = ToolTransformService.model_provider_to_user_provider( + db_provider=provider, ) + result_providers[f'model_provider.{provider.identity.name}'] = user_provider # get db api providers db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ filter(ApiToolProvider.tenant_id == tenant_id).all() for db_api_provider in db_api_providers: - username = 'Anonymous' - try: - username = db_api_provider.user.name - except Exception as e: - logger.error(f'failed to get user name for api provider {db_api_provider.id}: {str(e)}') - # add provider into providers - credentials = db_api_provider.credentials - provider_name = db_api_provider.name - result_providers[provider_name] = UserToolProvider( - id=db_api_provider.id, - author=username, - name=db_api_provider.name, - description=I18nObject( - en_US=db_api_provider.description, - zh_Hans=db_api_provider.description, - ), - icon=db_api_provider.icon, - label=I18nObject( - en_US=db_api_provider.name, - zh_Hans=db_api_provider.name, - ), - type=UserToolProvider.ProviderType.API, - team_credentials={}, - is_team_authorization=True, - ) - - # package tool provider controller - controller = ApiBasedToolProviderController.from_db( + provider_controller = ToolTransformService.api_provider_to_controller( db_provider=db_api_provider, - auth_type=ApiProviderAuthType.API_KEY if db_api_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) - - # init tool configuration - tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) - - # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) - - result_providers[provider_name].team_credentials = masked_credentials + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=db_api_provider, + ) + result_providers[db_api_provider.name] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index 70c6a44459..fa033b39f3 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -9,7 +9,6 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ToolCredentialsOption, - ToolParameter, ToolProviderCredentials, ) from core.tools.entities.user_entities import UserTool, UserToolProvider @@ -23,6 +22,7 @@ from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider from services.model_provider_service import ModelProviderService +from services.tools_transform_service import ToolTransformService class ToolManageService: @@ -70,7 +70,7 @@ class ToolManageService: @staticmethod def list_builtin_tool_provider_tools( user_id: str, tenant_id: str, provider: str - ): + ) -> list[UserTool]: """ list builtin tool provider tools """ @@ -92,41 +92,11 @@ class ToolManageService: result = [] for tool in tools: - # fork tool runtime - tool = tool.fork_tool_runtime(meta={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + result.append(ToolTransformService.tool_to_user_tool( + tool=tool, credentials=credentials, tenant_id=tenant_id + )) - # get tool parameters - parameters = tool.parameters or [] - # get tool runtime parameters - runtime_parameters = tool.get_runtime_parameters() - # override parameters - current_parameters = parameters.copy() - for runtime_parameter in runtime_parameters: - found = False - for index, parameter in enumerate(current_parameters): - if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: - current_parameters[index] = runtime_parameter - found = True - break - - if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: - current_parameters.append(runtime_parameter) - - user_tool = UserTool( - author=tool.identity.author, - name=tool.identity.name, - label=tool.identity.label, - description=tool.description.human, - parameters=current_parameters - ) - result.append(user_tool) - - return json.loads( - serialize_base_model_array(result) - ) + return result @staticmethod def list_builtin_provider_credentials_schema( @@ -318,7 +288,7 @@ class ToolManageService: @staticmethod def list_api_tool_provider_tools( user_id: str, tenant_id: str, provider: str - ): + ) -> list[UserTool]: """ list api tool provider tools """ @@ -330,23 +300,21 @@ class ToolManageService: if provider is None: raise ValueError(f'you have not added provider {provider}') - return json.loads( - serialize_base_model_array([ - UserTool( - author=tool_bundle.author, - name=tool_bundle.operation_id, - label=I18nObject( - en_US=tool_bundle.operation_id, - zh_Hans=tool_bundle.operation_id - ), - description=I18nObject( - en_US=tool_bundle.summary or '', - zh_Hans=tool_bundle.summary or '' - ), - parameters=tool_bundle.parameters - ) for tool_bundle in provider.tools - ]) - ) + return [ + UserTool( + author=tool_bundle.author, + name=tool_bundle.operation_id, + label=I18nObject( + en_US=tool_bundle.operation_id, + zh_Hans=tool_bundle.operation_id + ), + description=I18nObject( + en_US=tool_bundle.summary or '', + zh_Hans=tool_bundle.summary or '' + ), + parameters=tool_bundle.parameters + ) for tool_bundle in provider.tools + ] @staticmethod def update_builtin_tool_provider( @@ -527,7 +495,7 @@ class ToolManageService: @staticmethod def list_model_tool_provider_tools( user_id: str, tenant_id: str, provider: str - ): + ) -> list[UserTool]: """ list model tool provider tools """ @@ -655,4 +623,79 @@ class ToolManageService: except Exception as e: return { 'error': str(e) } - return { 'result': result or 'empty response' } \ No newline at end of file + return { 'result': result or 'empty response' } + + @staticmethod + def list_builtin_tools( + user_id: str, tenant_id: str + ) -> list[UserToolProvider]: + """ + list builtin tools + """ + # get all builtin providers + provider_controllers = ToolManager.list_builtin_providers() + + # get all user added providers + db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id + ).all() or [] + + # find provider + find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + + result: list[UserToolProvider] = [] + + for provider_controller in provider_controllers: + # convert provider controller to user provider + user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=find_provider(provider_controller.identity.name) + ) + + tools = provider_controller.get_tools() + for tool in tools: + user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_builtin_provider.original_credentials, + )) + + result.append(user_builtin_provider) + + return result + + @staticmethod + def list_api_tools( + user_id: str, tenant_id: str + ) -> list[UserToolProvider]: + """ + list api tools + """ + # get all api providers + db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( + ApiToolProvider.tenant_id == tenant_id + ).all() or [] + + result: list[UserToolProvider] = [] + + for provider in db_providers: + # convert provider controller to user provider + provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) + user_provider = ToolTransformService.api_provider_to_user_provider( + provider_controller, + db_provider=provider + ) + + tools = provider_controller.get_tools( + user_id=user_id, tenant_id=tenant_id + ) + for tool in tools: + user_provider.tools.append(ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_provider.original_credentials, + )) + + result.append(user_provider) + + return result \ No newline at end of file diff --git a/api/services/tools_transform_service.py b/api/services/tools_transform_service.py new file mode 100644 index 0000000000..8db7db62e4 --- /dev/null +++ b/api/services/tools_transform_service.py @@ -0,0 +1,199 @@ +import logging +from typing import Optional + +from core.model_runtime.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderCredentials +from core.tools.entities.user_entities import UserTool, UserToolProvider +from core.tools.provider.api_tool_provider import ApiBasedToolProviderController +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.model_tool_provider import ModelToolProviderController +from core.tools.tool.tool import Tool +from core.tools.utils.configuration import ToolConfigurationManager +from models.tools import ApiToolProvider, BuiltinToolProvider + +logger = logging.getLogger(__name__) + +class ToolTransformService: + @staticmethod + def builtin_provider_to_user_provider( + provider_controller: BuiltinToolProviderController, + db_provider: Optional[BuiltinToolProvider], + ) -> UserToolProvider: + """ + convert provider controller to user provider + """ + result = UserToolProvider( + id=provider_controller.identity.name, + author=provider_controller.identity.author, + name=provider_controller.identity.name, + description=I18nObject( + en_US=provider_controller.identity.description.en_US, + zh_Hans=provider_controller.identity.description.zh_Hans, + ), + icon=provider_controller.identity.icon, + label=I18nObject( + en_US=provider_controller.identity.label.en_US, + zh_Hans=provider_controller.identity.label.zh_Hans, + ), + type=UserToolProvider.ProviderType.BUILTIN, + masked_credentials={}, + is_team_authorization=False, + tools=[] + ) + + # get credentials schema + schema = provider_controller.get_credentials_schema() + for name, value in schema.items(): + result.masked_credentials[name] = \ + ToolProviderCredentials.CredentialsType.default(value.type) + + # check if the provider need credentials + if not provider_controller.need_credentials: + result.is_team_authorization = True + result.allow_delete = False + elif db_provider: + result.is_team_authorization = True + + credentials = db_provider.credentials + + # init tool configuration + tool_configuration = ToolConfigurationManager( + tenant_id=db_provider.tenant_id, + provider_controller=provider_controller + ) + # decrypt the credentials and mask the credentials + decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) + masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) + + result.masked_credentials = masked_credentials + result.original_credentials = decrypted_credentials + + return result + + @staticmethod + def api_provider_to_controller( + db_provider: ApiToolProvider, + ) -> ApiBasedToolProviderController: + """ + convert provider controller to user provider + """ + # package tool provider controller + controller = ApiBasedToolProviderController.from_db( + db_provider=db_provider, + auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + ) + + return controller + + @staticmethod + def api_provider_to_user_provider( + provider_controller: ApiBasedToolProviderController, + db_provider: ApiToolProvider, + ) -> UserToolProvider: + """ + convert provider controller to user provider + """ + username = 'Anonymous' + try: + username = db_provider.user.name + except Exception as e: + logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}') + # add provider into providers + credentials = db_provider.credentials + result = UserToolProvider( + id=db_provider.id, + author=username, + name=db_provider.name, + description=I18nObject( + en_US=db_provider.description, + zh_Hans=db_provider.description, + ), + icon=db_provider.icon, + label=I18nObject( + en_US=db_provider.name, + zh_Hans=db_provider.name, + ), + type=UserToolProvider.ProviderType.API, + masked_credentials={}, + is_team_authorization=True, + tools=[] + ) + + # init tool configuration + tool_configuration = ToolConfigurationManager( + tenant_id=db_provider.tenant_id, + provider_controller=provider_controller + ) + + # decrypt the credentials and mask the credentials + decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) + masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) + + result.masked_credentials = masked_credentials + + return result + + @staticmethod + def model_provider_to_user_provider( + db_provider: ModelToolProviderController, + ) -> UserToolProvider: + """ + convert provider controller to user provider + """ + return UserToolProvider( + id=db_provider.identity.name, + author=db_provider.identity.author, + name=db_provider.identity.name, + description=I18nObject( + en_US=db_provider.identity.description.en_US, + zh_Hans=db_provider.identity.description.zh_Hans, + ), + icon=db_provider.identity.icon, + label=I18nObject( + en_US=db_provider.identity.label.en_US, + zh_Hans=db_provider.identity.label.zh_Hans, + ), + type=UserToolProvider.ProviderType.MODEL, + masked_credentials={}, + is_team_authorization=db_provider.is_active, + ) + + @staticmethod + def tool_to_user_tool( + tool: Tool, credentials: dict = None, tenant_id: str = None + ) -> UserTool: + """ + convert tool to user tool + """ + # fork tool runtime + tool = tool.fork_tool_runtime(meta={ + 'credentials': credentials, + 'tenant_id': tenant_id, + }) + + # get tool parameters + parameters = tool.parameters or [] + # get tool runtime parameters + runtime_parameters = tool.get_runtime_parameters() or [] + # override parameters + current_parameters = parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + user_tool = UserTool( + author=tool.identity.author, + name=tool.identity.name, + label=tool.identity.label, + description=tool.description.human, + parameters=current_parameters + ) + + return user_tool \ No newline at end of file