diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0901b7e965..14602a7265 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -2,7 +2,6 @@ import json import logging import uuid from datetime import datetime -from mimetypes import guess_extension from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity @@ -39,7 +38,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool -from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_manager import ToolManager from extensions.ext_database import db from models.model import Message, MessageAgentThought, MessageFile @@ -462,73 +460,6 @@ class BaseAgentRunner(AppRunner): db.session.commit() db.session.close() - - def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: - """ - Transform tool message into agent thought - """ - result = [] - - for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.LINK: - result.append(message) - elif message.type == ToolInvokeMessage.MessageType.IMAGE: - # try to download image - try: - file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id, - file_url=message.message) - - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' - - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - except Exception as e: - logger.exception(e) - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=f"Failed to download image: 
{message.message}, you can try to download it yourself.", - meta=message.meta.copy() if message.meta is not None else {}, - save_as=message.save_as, - )) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get mime type and save blob to storage - mimetype = message.meta.get('mime_type', 'octet/stream') - # if message is str, encode it to bytes - if isinstance(message.message, str): - message.message = message.message.encode('utf-8') - file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id, - conversation_id=self.message.conversation_id, - file_binary=message.message, - mimetype=mimetype) - - url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}' - - # check if file is image - if 'image' in mimetype: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - else: - result.append(ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=url, - save_as=message.save_as, - meta=message.meta.copy() if message.meta is not None else {}, - )) - else: - result.append(message) - - return result def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): """ diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index cbb19aca53..0c5399f541 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -25,6 +25,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) +from core.tools.utils.message_transformer import ToolFileMessageTransformer from models.model import Conversation, Message @@ -280,7 +281,12 @@ class CotAgentRunner(BaseAgentRunner): tool_parameters=tool_call_args ) # transform tool response to llm friendly response - tool_response = self.transform_tool_invoke_messages(tool_response) + 
tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=tool_response, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id + ) # extract binary data from tool invoke message binary_files = self.extract_tool_response_binary(tool_response) # create message file diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 7c3849a12c..185d7684c8 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -23,6 +23,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) +from core.tools.utils.message_transformer import ToolFileMessageTransformer from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -270,7 +271,12 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_parameters=tool_call_args, ) # transform tool invoke message to get LLM friendly message - tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message) + tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=tool_invoke_message, + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=self.message.conversation_id + ) # extract binary data from tool invoke message binary_files = self.extract_tool_response_binary(tool_invoke_message) # create message file diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 24b2f287c1..ea66362195 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -34,6 +34,7 @@ from core.tools.utils.configuration import ( ToolParameterConfigurationManager, ) from core.tools.utils.encoder import serialize_base_model_dict +from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -225,6 +226,48 @@ class ToolManager: else: raise 
@staticmethod
def _init_runtime_parameter(parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool, None]:
    """
    Resolve and normalise the runtime value of a single tool parameter.

    The value is looked up in ``parameters`` under ``parameter_rule.name``. When it is
    not configured, ``parameter_rule.default`` is used instead.

    :param parameter_rule: rule describing the parameter; ``name``, ``default``,
        ``required``, ``type`` and (for SELECT) ``options`` are read
    :param parameters: mapping of parameter name to the configured value
    :return: the value converted according to ``parameter_rule.type``. For an optional
        parameter with neither a value nor a usable default, the default is returned
        untouched (``False`` for BOOLEAN)
    :raises ValueError: if a required parameter has no value and no default, if a SELECT
        value is not one of the declared options, or if the value cannot be converted
    """
    parameter_types = ToolParameter.ToolParameterType
    rule_type = parameter_rule.type

    def is_unset(value) -> bool:
        # Only None and the empty string mean "not configured". Legitimate falsy
        # values such as 0 or False must not be replaced by the default.
        return value is None or value == ''

    parameter_value = parameters.get(parameter_rule.name)
    if is_unset(parameter_value):
        # get default value
        parameter_value = parameter_rule.default

    if is_unset(parameter_value):
        if parameter_rule.required:
            raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
        # Optional parameter left unset: there is nothing to validate or convert.
        # Converting here would turn None into the string "None", and validating
        # would reject every unset SELECT. BOOLEAN keeps its previous falsy result.
        return False if rule_type == parameter_types.BOOLEAN else parameter_value

    if rule_type == parameter_types.SELECT:
        # check if the value is one of the declared options
        options = [option.value for option in parameter_rule.options]
        if parameter_value not in options:
            raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")

    # convert tool parameter config to correct type
    try:
        if rule_type == parameter_types.NUMBER:
            # int and float values are already numeric; only strings need parsing
            if isinstance(parameter_value, str):
                if '.' in parameter_value:
                    parameter_value = float(parameter_value)
                else:
                    parameter_value = int(parameter_value)
        elif rule_type == parameter_types.BOOLEAN:
            # NOTE(review): bool() of any non-empty string is True, including "false";
            # confirm whether string input can reach this branch.
            parameter_value = bool(parameter_value)
        elif rule_type not in [parameter_types.SELECT, parameter_types.STRING]:
            parameter_value = str(parameter_value)
        # SELECT and STRING values are passed through unchanged.
    except Exception as e:
        # chain the cause so the underlying conversion error is not lost
        raise ValueError(
            f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type"
        ) from e

    return parameter_value
isinstance(tool_parameter_config, str): - if '.' in tool_parameter_config: - tool_parameter_config = float(tool_parameter_config) - else: - tool_parameter_config = int(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - tool_parameter_config = bool(tool_parameter_config) - elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: - tool_parameter_config = str(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType: - tool_parameter_config = str(tool_parameter_config) - except Exception as e: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") - # save tool parameter to tool entity memory - runtime_parameters[parameter.name] = tool_parameter_config + value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters) + runtime_parameters[parameter.name] = value # decrypt runtime parameters encryption_manager = ToolParameterConfigurationManager( @@ -289,6 +297,38 @@ class ToolManager: tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity + + @staticmethod + def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler): + """ + get the workflow tool runtime + """ + tool_entity = ToolManager.get_tool_runtime( + provider_type=workflow_tool.provider_type, + provider_name=workflow_tool.provider_id, + tool_name=workflow_tool.tool_name, + tenant_id=tenant_id, + agent_callback=agent_callback + ) + runtime_parameters = {} + parameters = tool_entity.get_all_runtime_parameters() + + for parameter in parameters: + # save tool parameter to tool entity memory + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters) + runtime_parameters[parameter.name] = value + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + 
class ToolFileMessageTransformer:
    """
    Rewrites tool invoke messages so that image and blob payloads are stored through
    ``ToolFileManager`` and replaced by ``/files/tools/...`` links.
    """

    @staticmethod
    def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
                                       user_id: str,
                                       tenant_id: str,
                                       conversation_id: str) -> list[ToolInvokeMessage]:
        """
        Transform tool message and handle file download

        :param messages: messages returned by a tool invocation
        :param user_id: forwarded to ``ToolFileManager`` when a file is created
        :param tenant_id: forwarded to ``ToolFileManager`` when a file is created
        :param conversation_id: forwarded to ``ToolFileManager`` when a file is created
        :return: a new list in the same order. IMAGE messages become IMAGE_LINK, or a
            TEXT notice if the download fails. BLOB messages become IMAGE_LINK or LINK
            depending on their mime type. All other messages are passed through unchanged.
        """
        message_types = ToolInvokeMessage.MessageType
        result = []

        for message in messages:
            if message.type == message_types.IMAGE:
                # meta may be None, so always hand a private dict to the new message
                meta = message.meta.copy() if message.meta is not None else {}
                # try to download image
                try:
                    file = ToolFileManager.create_file_by_url(
                        user_id=user_id,
                        tenant_id=tenant_id,
                        conversation_id=conversation_id,
                        file_url=message.message
                    )

                    url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'

                    result.append(ToolInvokeMessage(
                        type=message_types.IMAGE_LINK,
                        message=url,
                        save_as=message.save_as,
                        meta=meta,
                    ))
                except Exception as e:
                    # best effort: tell the consumer about the failure instead of aborting
                    logger.exception(e)
                    result.append(ToolInvokeMessage(
                        type=message_types.TEXT,
                        message=f"Failed to download image: {message.message}, you can try to download it yourself.",
                        meta=meta,
                        save_as=message.save_as,
                    ))
            elif message.type == message_types.BLOB:
                meta = message.meta.copy() if message.meta is not None else {}
                # get mime type and save blob to storage; a missing or empty
                # mime type falls back to the default
                mimetype = meta.get('mime_type') or 'octet/stream'
                # if message is str, encode it to bytes, without mutating the
                # caller's message object
                blob = message.message
                if isinstance(blob, str):
                    blob = blob.encode('utf-8')

                file = ToolFileManager.create_file_by_raw(
                    user_id=user_id,
                    tenant_id=tenant_id,
                    conversation_id=conversation_id,
                    file_binary=blob,
                    mimetype=mimetype
                )

                url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'

                # images are exposed as image links, everything else as plain links
                link_type = message_types.IMAGE_LINK if 'image' in mimetype else message_types.LINK
                result.append(ToolInvokeMessage(
                    type=link_type,
                    message=url,
                    save_as=message.save_as,
                    meta=meta,
                ))
            else:
                # TEXT, LINK and any other type need no file handling
                result.append(message)

        return result
class ToolNode(BaseNode):
    """
    Tool Node

    Resolves the node's tool inputs from the variable pool, invokes the configured tool
    and exposes its output as plain text plus a list of file descriptors.
    """
    _node_data_cls = ToolNodeData
    _node_type = NodeType.TOOL

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run the tool node

        :param variable_pool: pool against which each tool input selector is resolved
        :return: a FAILED result if the tool runtime cannot be created or the invocation
            raises; otherwise a SUCCESS result with ``text`` and ``files`` outputs
        :raises ValueError: if the resolved parameters do not match the declared inputs
        """
        node_data = cast(ToolNodeData, self.node_data)

        # extract tool parameters
        parameters = {
            tool_input.variable: variable_pool.get_variable_value(tool_input.value_selector)
            for tool_input in node_data.tool_inputs
        }

        # inputs sharing a variable name collapse into one dict key; the length
        # comparison detects that
        if len(parameters) != len(node_data.tool_inputs):
            raise ValueError('Invalid tool parameters')

        # get tool runtime
        try:
            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None)
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                error=f'Failed to get tool runtime: {str(e)}'
            )

        try:
            messages = tool_runtime.invoke(None, parameters)
        except Exception as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=parameters,
                error=f'Failed to invoke tool: {str(e)}'
            )

        # convert tool messages
        plain_text, files = self._convert_tool_messages(messages)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCESS,
            # report the inputs on success too, as the failure results already do
            inputs=parameters,
            outputs={
                'text': plain_text,
                'files': files
            },
        )

    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]

        :param messages: raw messages returned by the tool invocation
        :return: the concatenated text output and the list of file descriptors
        """
        # transform message and handle file storage; the transformer requires the
        # owner identifiers, so they must be passed explicitly.
        # NOTE(review): assumes BaseNode exposes ``user_id`` alongside ``tenant_id`` - confirm.
        # NOTE(review): no conversation is visible from this node, so None is passed;
        # confirm ToolFileManager accepts a missing conversation id.
        messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )
        # extract plain text and files
        files = self._extract_tool_response_binary(messages)
        plain_text = self._extract_tool_response_text(messages)

        return plain_text, files

    def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
        """
        Extract tool response binary

        :param tool_response: messages to scan
        :return: one file descriptor dict per IMAGE_LINK, IMAGE or BLOB message;
            other message types contribute nothing
        """
        message_types = ToolInvokeMessage.MessageType
        result = []

        for response in tool_response:
            # meta may be None on messages that were passed through untouched
            meta = response.meta or {}

            if response.type in (message_types.IMAGE_LINK, message_types.IMAGE):
                url = response.message
                result.append({
                    'type': 'image',
                    'transfer_method': FileTransferMethod.TOOL_FILE,
                    'url': url,
                    'upload_file_id': None,
                    # fall back to the last URL segment when no name was requested
                    'filename': response.save_as or url.split('/')[-1],
                    'file-ext': path.splitext(url)[1],
                    'mime-type': meta.get('mime_type', 'image/jpeg'),
                })
            elif response.type == message_types.BLOB:
                result.append({
                    'type': 'image',  # TODO: only support image for now
                    'transfer_method': FileTransferMethod.TOOL_FILE,
                    'url': response.message,
                    'upload_file_id': None,
                    'filename': response.save_as,
                    # save_as may be unset; splitext() does not accept None
                    'file-ext': path.splitext(response.save_as or '')[1],
                    'mime-type': meta.get('mime_type', 'application/octet-stream'),
                })
            elif response.type == message_types.LINK:
                pass  # TODO: links are not exported as files yet

        return result

    def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
        """
        Extract tool response text

        :param tool_response: messages to scan
        :return: TEXT messages and ``Link: ``-prefixed LINK messages, one per line,
            in their original order; other message types are ignored
        """
        message_types = ToolInvokeMessage.MessageType
        lines = []

        for message in tool_response:
            if message.type == message_types.TEXT:
                lines.append(f'{message.message}\n')
            elif message.type == message_types.LINK:
                lines.append(f'Link: {message.message}\n')

        return ''.join(lines)

    @staticmethod
    def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict:
        """
        Convert ToolInvokeMessage into file

        Not implemented yet: the body is empty and currently returns None.
        Declared as a static method because it takes no ``self``; this keeps it
        callable both on the class and on an instance.
        """
        pass

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
        """
        Extract variable selector to variable mapping

        Not implemented yet: the body is empty and currently returns None.
        NOTE(review): a list is unhashable and cannot be a dict key, so the annotated
        return type cannot be the final one - confirm the intended mapping direction.
        """
        pass