From 51404f90355020b0cb7d1603a2f13b0ee38b7c27 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 28 Mar 2024 18:36:58 +0800 Subject: [PATCH] refactor: tool engine --- api/controllers/console/app/app.py | 1 - api/controllers/console/app/model_config.py | 2 - api/core/agent/base_agent_runner.py | 97 +-------- api/core/agent/cot_agent_runner.py | 97 ++++----- api/core/agent/fc_agent_runner.py | 104 +++------- api/core/tools/tool/tool.py | 52 +---- api/core/tools/tool_engine.py | 218 ++++++++++++++++++++ api/core/tools/tool_manager.py | 30 +-- 8 files changed, 318 insertions(+), 283 deletions(-) create mode 100644 api/core/tools/tool_engine.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9c8ebfac6c..9c362a9ed0 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -123,7 +123,6 @@ class AppApi(Resource): tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, agent_tool=agent_tool_entity, - agent_callback=None ) manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 41b7151ba6..a7eaee3460 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -58,7 +58,6 @@ class ModelConfigResource(Resource): tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, agent_tool=agent_tool_entity, - agent_callback=None ) manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, @@ -96,7 +95,6 @@ class ModelConfigResource(Resource): tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, agent_tool=agent_tool_entity, - agent_callback=None ) except Exception as e: continue diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index b93ac56916..eb3e5ac193 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -10,12 +10,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, - InvokeFrom, ModelConfigWithCredentialsEntity, ) from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.file.message_file_parser import FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage @@ -32,7 +30,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.tool_entities import ( ToolInvokeMessage, - ToolInvokeMessageBinary, ToolParameter, ToolRuntimeVariablePool, ) @@ -40,7 +37,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool from core.tools.tool_manager import ToolManager from extensions.ext_database import db -from models.model import Message, MessageAgentThought, MessageFile +from models.model import Message, MessageAgentThought from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) @@ -156,7 +153,6 @@ class BaseAgentRunner(AppRunner): tool_entity = ToolManager.get_agent_tool_runtime( tenant_id=self.tenant_id, agent_tool=tool, - agent_callback=self.agent_callback ) tool_entity.load_variables(self.variables_pool) @@ -270,87 +266,6 @@ class BaseAgentRunner(AppRunner): prompt_tool.parameters['required'].append(parameter.name) return prompt_tool - - def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: - """ - Extract tool response binary - """ - result = [] - - for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream'), - url=response.message, - save_as=response.save_as, - )) - elif response.type == ToolInvokeMessage.MessageType.BLOB: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream'), - url=response.message, - save_as=response.save_as, - )) - elif response.type == ToolInvokeMessage.MessageType.LINK: - # check if there is a mime type in meta - if response.meta and 'mime_type' in response.meta: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', - url=response.message, - save_as=response.save_as, - )) - - return result - - def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]: - """ - Create message file - - :param messages: messages - :return: message files, should save as variable - """ - result = [] - - for message in messages: - file_type = 'bin' - if 'image' in message.mimetype: - file_type = 'image' - elif 'video' in message.mimetype: - file_type = 'video' - elif 'audio' in message.mimetype: - file_type = 'audio' - elif 'text' in message.mimetype: - file_type = 'text' - elif 'pdf' in message.mimetype: - file_type = 'pdf' - elif 'zip' in message.mimetype: - file_type = 'archive' - # ... - - invoke_from = self.application_generate_entity.invoke_from - - message_file = MessageFile( - message_id=self.message.id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE.value, - belongs_to='assistant', - url=message.url, - upload_file_id=None, - created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), - created_by=self.user_id, - ) - db.session.add(message_file) - db.session.commit() - db.session.refresh(message_file) - - result.append(( - message_file, - message.save_as - )) - - db.session.close() - - return result def create_agent_thought(self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] @@ -500,8 +415,12 @@ class BaseAgentRunner(AppRunner): try: tool_inputs = json.loads(agent_thought.tool_input) except Exception as e: - logging.warning("tool execution error: {}, tool_input: {}.".format(str(e), agent_thought.tool_input)) - tool_inputs = { agent_thought.tool: agent_thought.tool_input } + tool_inputs = { tool: {} for tool in tools } + try: + tool_responses = json.loads(agent_thought.observation) + except Exception as e: + tool_responses = { tool: agent_thought.observation for tool in tools } + for tool in tools: # generate a uuid for tool call tool_call_id = str(uuid.uuid4()) @@ -514,7 +433,7 @@ class BaseAgentRunner(AppRunner): ) )) tool_call_response.append(ToolPromptMessage( - content=agent_thought.observation, + content=tool_responses.get(tool, agent_thought.observation), name=tool, tool_call_id=tool_call_id, )) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 81e082909d..c8191552ee 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -17,15 +17,7 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.errors import ( - ToolInvokeError, - ToolNotFoundError, - ToolNotSupportedError, - ToolParameterValidationError, - ToolProviderCredentialValidationError, - ToolProviderNotFoundError, -) -from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.tool_engine import ToolEngine from models.model import Conversation, Message @@ -267,60 +259,47 @@ class CotAgentRunner(BaseAgentRunner): agent_thought_id=agent_thought.id ), PublishFrom.APPLICATION_MANAGER) else: + if isinstance(tool_call_args, str): + try: + tool_call_args = json.loads(tool_call_args) + except json.JSONDecodeError: + pass + # invoke tool - error_response = None - try: - if isinstance(tool_call_args, str): - try: - tool_call_args = json.loads(tool_call_args) - except json.JSONDecodeError: - pass + tool_invoke_response, message_files = ToolEngine.agent_invoke( + tool=tool_instance, + tool_parameters=tool_call_args, + user_id=self.user_id, + tenant_id=self.tenant_id, + message=self.message, + invoke_from=self.application_generate_entity.invoke_from, + agent_tool_callback=self.agent_callback + ) + # publish files + for message_file, save_as in message_files: + if save_as: + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) - tool_response = tool_instance.invoke( - user_id=self.user_id, - tool_parameters=tool_call_args - ) - # transform tool response to llm friendly response - tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=tool_response, - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id - ) - # extract binary data from tool invoke message - binary_files = self.extract_tool_response_binary(tool_response) - # create message file - message_files = self.create_message_files(binary_files) - # publish files - for message_file, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, - value=message_file.id, - name=save_as) - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), PublishFrom.APPLICATION_MANAGER) + # publish message file + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) + # add message file ids + message_file_ids.append(message_file.id) - message_file_ids = [message_file.id for message_file, _ in message_files] - except ToolProviderCredentialValidationError as e: - error_response = "Please check your tool provider credentials" - except ( - ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError - ) as e: - error_response = f"there is not a tool named {tool_call_name}" - except ( - ToolParameterValidationError - ) as e: - error_response = f"tool parameters validation error: {e}, please check your tool parameters" - except ToolInvokeError as e: - error_response = f"tool invoke error: {e}" - except Exception as e: - error_response = f"unknown error: {e}" + # publish files + for message_file, save_as in message_files: + if save_as: + self.variables_pool.set_file(tool_name=tool_call_name, + value=message_file.id, + name=save_as) + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) - if error_response: - observation = error_response - else: - observation = self._convert_tool_response_to_str(tool_response) + message_file_ids = [message_file.id for message_file, _ in message_files] + + observation = tool_invoke_response # save scratchpad scratchpad.observation = observation diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 185d7684c8..533aff46fd 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -15,15 +15,7 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.tools.errors import ( - ToolInvokeError, - ToolNotFoundError, - ToolNotSupportedError, - ToolParameterValidationError, - ToolProviderCredentialValidationError, - ToolProviderNotFoundError, -) -from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.tool_engine import ToolEngine from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -261,70 +253,37 @@ class FunctionCallAgentRunner(BaseAgentRunner): "tool_call_name": tool_call_name, "tool_response": f"there is not a tool named {tool_call_name}" } - tool_responses.append(tool_response) else: # invoke tool - error_response = None - try: - tool_invoke_message = tool_instance.invoke( - user_id=self.user_id, - tool_parameters=tool_call_args, - ) - # transform tool invoke message to get LLM friendly message - tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=tool_invoke_message, - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id - ) - # extract binary data from tool invoke message - binary_files = self.extract_tool_response_binary(tool_invoke_message) - # create message file - message_files = self.create_message_files(binary_files) - # publish files - for message_file, save_as in message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) - - # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file.id - ), PublishFrom.APPLICATION_MANAGER) - # add message file ids - message_file_ids.append(message_file.id) - - except ToolProviderCredentialValidationError as e: - error_response = "Please check your tool provider credentials" - except ( - ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError - ) as e: - error_response = f"there is not a tool named {tool_call_name}" - except ( - ToolParameterValidationError - ) as e: - error_response = f"tool parameters validation error: {e}, please check your tool parameters" - except ToolInvokeError as e: - error_response = f"tool invoke error: {e}" - except Exception as e: - error_response = f"unknown error: {e}" - - if error_response: - observation = error_response - tool_response = { - "tool_call_id": tool_call_id, - "tool_call_name": tool_call_name, - "tool_response": error_response - } - tool_responses.append(tool_response) - else: - observation = self._convert_tool_response_to_str(tool_invoke_message) - tool_response = { - "tool_call_id": tool_call_id, - "tool_call_name": tool_call_name, - "tool_response": observation - } - tool_responses.append(tool_response) + tool_invoke_response, message_files = ToolEngine.agent_invoke( + tool=tool_instance, + tool_parameters=tool_call_args, + user_id=self.user_id, + tenant_id=self.tenant_id, + message=self.message, + invoke_from=self.application_generate_entity.invoke_from, + agent_tool_callback=self.agent_callback, + ) + # publish files + for message_file, save_as in message_files: + if save_as: + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as) + # publish message file + self.queue_manager.publish(QueueMessageFileEvent( + message_file_id=message_file.id + ), PublishFrom.APPLICATION_MANAGER) + # add message file ids + message_file_ids.append(message_file.id) + + observation = tool_invoke_response + tool_response = { + "tool_call_id": tool_call_id, + "tool_call_name": tool_call_name, + "tool_response": observation + } + + tool_responses.append(tool_response) prompt_messages = self.organize_prompt_messages( prompt_template=prompt_template, query=None, @@ -341,7 +300,10 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_name=None, tool_input=None, thought=None, - observation=tool_response['tool_response'], + observation=json.dumps({ + tool_response['tool_call_name']: tool_response['tool_response'] + for tool_response in tool_responses + }), answer=None, messages_ids=message_file_ids ) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 103fb931c5..8f556cfa1a 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -4,7 +4,6 @@ from typing import Any, Optional, Union from pydantic import BaseModel -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.tools.entities.tool_entities import ( ToolDescription, ToolIdentity, @@ -22,8 +21,6 @@ class Tool(BaseModel, ABC): parameters: Optional[list[ToolParameter]] = None description: ToolDescription = None is_team_authorization: bool = False - agent_callback: Optional[DifyAgentCallbackHandler] = None - use_callback: bool = False class Runtime(BaseModel): """ @@ -45,15 +42,10 @@ class Tool(BaseModel, ABC): def __init__(self, **data: Any): super().__init__(**data) - if not self.agent_callback: - self.use_callback = False - else: - self.use_callback = True - class VARIABLE_KEY(Enum): IMAGE = 'image' - def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool': + def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool': """ fork a new tool with meta data @@ -65,7 +57,6 @@ class Tool(BaseModel, ABC): parameters=self.parameters.copy() if self.parameters else None, description=self.description.copy() if self.description else None, runtime=Tool.Runtime(**meta), - agent_callback=agent_callback ) def load_variables(self, variables: ToolRuntimeVariablePool): @@ -174,50 +165,19 @@ class Tool(BaseModel, ABC): return result - def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]: - # check if tool_parameters is a string - if isinstance(tool_parameters, str): - # check if this tool has only one parameter - parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM] - if parameters and len(parameters) == 1: - tool_parameters = { - parameters[0].name: tool_parameters - } - else: - raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") - + def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: # update tool_parameters if self.runtime.runtime_parameters: tool_parameters.update(self.runtime.runtime_parameters) - # hit callback - if self.use_callback: - self.agent_callback.on_tool_start( - tool_name=self.identity.name, - tool_inputs=tool_parameters - ) - - try: - result = self._invoke( - user_id=user_id, - tool_parameters=tool_parameters, - ) - except Exception as e: - if self.use_callback: - self.agent_callback.on_tool_error(e) - raise e + result = self._invoke( + user_id=user_id, + tool_parameters=tool_parameters, + ) if not isinstance(result, list): result = [result] - # hit callback - if self.use_callback: - self.agent_callback.on_tool_end( - tool_name=self.identity.name, - tool_inputs=tool_parameters, - tool_outputs=self._convert_tool_response_to_str(result) - ) - return result def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py new file mode 100644 index 0000000000..7d836b4695 --- /dev/null +++ b/api/core/tools/tool_engine.py @@ -0,0 +1,218 @@ +from typing import Union + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.file.file_obj import FileTransferMethod +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter +from core.tools.errors import ( + ToolInvokeError, + ToolNotFoundError, + ToolNotSupportedError, + ToolParameterValidationError, + ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) +from core.tools.tool.tool import Tool +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from extensions.ext_database import db +from models.model import Message, MessageFile + + +class ToolEngine: + """ + Tool runtime engine take care of the tool executions. + """ + @staticmethod + def agent_invoke(tool: Tool, tool_parameters: Union[str, dict], + user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, + agent_tool_callback: DifyAgentCallbackHandler) \ + -> tuple[str, list[tuple[MessageFile, bool]]]: + """ + Agent invokes the tool with the given arguments. + """ + # check if arguments is a string + if isinstance(tool_parameters, str): + # check if this tool has only one parameter + parameters = [ + parameter for parameter in tool.parameters + if parameter.form == ToolParameter.ToolParameterForm.LLM + ] + if parameters and len(parameters) == 1: + tool_parameters = { + parameters[0].name: tool_parameters + } + else: + raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") + + # invoke the tool + try: + # hit the callback handler + agent_tool_callback.on_tool_start( + tool_name=tool.identity.name, + tool_inputs=tool_parameters + ) + + response = tool.invoke(user_id, tool_parameters) + + response = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=response, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=message.conversation_id + ) + + # extract binary data from tool invoke message + binary_files = ToolEngine._extract_tool_response_binary(response) + # create message file + message_files = ToolEngine._create_message_files( + tool_messages=binary_files, + agent_message=message, + invoke_from=invoke_from, + user_id=user_id + ) + + plain_text = ToolEngine._convert_tool_response_to_str(response) + + # hit the callback handler + agent_tool_callback.on_tool_end( + tool_name=tool.identity.name, + tool_inputs=tool_parameters, + tool_outputs=plain_text + ) + + # transform tool invoke message to get LLM friendly message + return plain_text, message_files + except ToolProviderCredentialValidationError as e: + error_response = "Please check your tool provider credentials" + agent_tool_callback.on_tool_error(e) + except ( + ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError + ) as e: + error_response = f"there is not a tool named {tool.identity.name}" + agent_tool_callback.on_tool_error(e) + except ( + ToolParameterValidationError + ) as e: + error_response = f"tool parameters validation error: {e}, please check your tool parameters" + agent_tool_callback.on_tool_error(e) + except ToolInvokeError as e: + error_response = f"tool invoke error: {e}" + agent_tool_callback.on_tool_error(e) + except Exception as e: + error_response = f"unknown error: {e}" + agent_tool_callback.on_tool_error(e) + + return error_response, [] + + @staticmethod + def workflow_invoke(tool: Tool, tool_parameters: dict, + user_id: str, workflow_id: str) -> dict: + """ + Workflow invokes the tool with the given arguments. + """ + + @staticmethod + def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: + """ + Handle tool response + """ + result = '' + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.TEXT: + result += response.message + elif response.type == ToolInvokeMessage.MessageType.LINK: + result += f"result link: {response.message}. please tell user to check it." + elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ + response.type == ToolInvokeMessage.MessageType.IMAGE: + result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." + else: + result += f"tool response: {response.message}." + + return result + + @staticmethod + def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: + """ + Extract tool response binary + """ + result = [] + + for response in tool_response: + if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ + response.type == ToolInvokeMessage.MessageType.IMAGE: + result.append(ToolInvokeMessageBinary( + mimetype=response.meta.get('mime_type', 'octet/stream'), + url=response.message, + save_as=response.save_as, + )) + elif response.type == ToolInvokeMessage.MessageType.BLOB: + result.append(ToolInvokeMessageBinary( + mimetype=response.meta.get('mime_type', 'octet/stream'), + url=response.message, + save_as=response.save_as, + )) + elif response.type == ToolInvokeMessage.MessageType.LINK: + # check if there is a mime type in meta + if response.meta and 'mime_type' in response.meta: + result.append(ToolInvokeMessageBinary( + mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', + url=response.message, + save_as=response.save_as, + )) + + return result + + @staticmethod + def _create_message_files( + tool_messages: list[ToolInvokeMessageBinary], + agent_message: Message, + invoke_from: InvokeFrom, + user_id: str + ) -> list[tuple[MessageFile, bool]]: + """ + Create message file + + :param messages: messages + :return: message files, should save as variable + """ + result = [] + + for message in tool_messages: + file_type = 'bin' + if 'image' in message.mimetype: + file_type = 'image' + elif 'video' in message.mimetype: + file_type = 'video' + elif 'audio' in message.mimetype: + file_type = 'audio' + elif 'text' in message.mimetype: + file_type = 'text' + elif 'pdf' in message.mimetype: + file_type = 'pdf' + elif 'zip' in message.mimetype: + file_type = 'archive' + # ... + + message_file = MessageFile( + message_id=agent_message.id, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE.value, + belongs_to='assistant', + url=message.url, + upload_file_id=None, + created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), + created_by=user_id, + ) + + db.session.add(message_file) + db.session.commit() + db.session.refresh(message_file) + + result.append(( + message_file, + message.save_as + )) + + db.session.close() + + return result \ No newline at end of file diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 3a0278bc01..049dc0a4ed 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,6 @@ from os import listdir, path from typing import Any, Union from core.agent.entities import AgentToolEntity -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.model_runtime.entities.message_entities import PromptMessage from core.provider_manager import ProviderManager from core.tools.entities.common_entities import I18nObject @@ -139,8 +138,7 @@ class ToolManager: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') @staticmethod - def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str, - agent_callback: DifyAgentCallbackHandler = None) \ + def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \ -> Union[BuiltinTool, ApiTool]: """ get the tool runtime @@ -160,7 +158,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime(meta={ 'tenant_id': tenant_id, 'credentials': {}, - }, agent_callback=agent_callback) + }) # get credentials builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( @@ -182,7 +180,7 @@ class ToolManager: 'tenant_id': tenant_id, 'credentials': decrypted_credentials, 'runtime_parameters': {} - }, agent_callback=agent_callback) + }) elif provider_type == 'api': if tenant_id is None: @@ -259,14 +257,13 @@ class ToolManager: return parameter_value @staticmethod - def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool: + def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity) -> Tool: """ get the agent tool runtime """ tool_entity = ToolManager.get_tool_runtime( provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, - agent_callback=agent_callback ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -289,7 +286,7 @@ class ToolManager: return tool_entity @staticmethod - def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler): + def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity): """ get the workflow tool runtime """ @@ -298,7 +295,6 @@ class ToolManager: provider_name=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, - agent_callback=agent_callback ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -364,12 +360,16 @@ class ToolManager: continue # init provider - provider_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.{provider}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), - parent_type=BuiltinToolProviderController) - builtin_providers.append(provider_class()) + try: + provider_class = load_single_subclass_from_source( + module_name=f'core.tools.provider.builtin.{provider}.{provider}', + script_path=path.join(path.dirname(path.realpath(__file__)), + 'provider', 'builtin', provider, f'{provider}.py'), + parent_type=BuiltinToolProviderController) + builtin_providers.append(provider_class()) + except Exception as e: + logger.error(f'load builtin provider {provider} error: {e}') + continue # cache the builtin providers for provider in builtin_providers: