diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index bb4649eaa4..3865695c71 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,6 +1,7 @@ -from collections.abc import Generator, Iterable, Sequence +from collections.abc import Generator, Iterable, Mapping, Sequence from os import path -from typing import Any, Mapping, cast +from typing import Any, cast +from urllib import response from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler @@ -13,6 +14,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunStreamChunkEvent from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser from models import WorkflowNodeExecutionStatus @@ -26,7 +28,7 @@ class ToolNode(BaseNode): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[RunEvent]: """ Run the tool node """ @@ -45,22 +47,34 @@ class ToolNode(BaseNode): self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from ) except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to get tool runtime: {str(e)}' + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + metadata={ + NodeRunMetadataKey.TOOL_INFO: tool_info + }, + error=f'Failed to get tool runtime: {str(e)}' + ) ) + return # get parameters tool_parameters = tool_runtime.get_runtime_parameters() or [] - parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data) - parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True) + parameters = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=node_data + ) + parameters_for_log = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=node_data, + for_log=True + ) try: - messages = ToolEngine.workflow_invoke( + message_stream = ToolEngine.workflow_invoke( tool=tool_runtime, tool_parameters=parameters, user_id=self.user_id, @@ -69,30 +83,33 @@ class ToolNode(BaseNode): thread_pool_id=self.thread_pool_id, ) except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - error=f'Failed to invoke tool: {str(e)}', + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + metadata={ + NodeRunMetadataKey.TOOL_INFO: tool_info + }, + error=f'Failed to invoke tool: {str(e)}', + ) ) + return # convert tool messages - plain_text, files, json = self._convert_tool_messages(messages) + yield from self._transform_message(message_stream, tool_info, parameters_for_log) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'text': plain_text, - 'files': files, - 'json': json - }, - metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info - }, - inputs=parameters_for_log - ) + # return NodeRunResult( + # status=WorkflowNodeExecutionStatus.SUCCEEDED, + # outputs={ + # 'text': plain_text, + # 'files': files, + # 'json': json + # }, + # metadata={ + # NodeRunMetadataKey.TOOL_INFO: tool_info + # }, + # inputs=parameters_for_log + # ) def _generate_parameters( self, @@ -148,48 +165,40 @@ class ToolNode(BaseNode): assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _convert_tool_messages(self, messages: Generator[ToolInvokeMessage, None, None]): + def _transform_message(self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any]) -> Generator[RunEvent, None, None]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ # transform message and handle file storage - messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( messages=messages, user_id=self.user_id, tenant_id=self.tenant_id, conversation_id=None, ) - result = list(messages) + files: list[FileVar] = [] + text = "" + json: list[dict] = [] - # extract plain text and files - files = self._extract_tool_response_binary(result) - plain_text = self._extract_tool_response_text(result) - json = self._extract_tool_response_json(result) + for message in message_stream: + if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ + message.type == ToolInvokeMessage.MessageType.IMAGE: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta - return plain_text, files, json - - def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]: - """ - Extract tool response binary - """ - result = [] - - for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: - assert isinstance(response.message, ToolInvokeMessage.TextMessage) - assert response.meta - - url = response.message.text + url = message.message.text ext = path.splitext(url)[1] - mimetype = response.meta.get('mime_type', 'image/jpeg') - filename = response.save_as or url.split('/')[-1] - transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE) + mimetype = message.meta.get('mime_type', 'image/jpeg') + filename = message.save_as or url.split('/')[-1] + transfer_method = message.meta.get('transfer_method', FileTransferMethod.TOOL_FILE) # get tool file id tool_file_id = url.split('/')[-1].split('.')[0] - result.append(FileVar( + files.append(FileVar( tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=transfer_method, @@ -199,48 +208,54 @@ class ToolNode(BaseNode): extension=ext, mime_type=mimetype, )) - elif response.type == ToolInvokeMessage.MessageType.BLOB: + elif message.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id - assert isinstance(response.message, ToolInvokeMessage.TextMessage) - assert response.meta + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta - tool_file_id = response.message.text.split('/')[-1].split('.')[0] - result.append(FileVar( + tool_file_id = message.message.text.split('/')[-1].split('.')[0] + files.append(FileVar( tenant_id=self.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=tool_file_id, - filename=response.save_as, - extension=path.splitext(response.save_as)[1], - mime_type=response.meta.get('mime_type', 'application/octet-stream'), + filename=message.save_as, + extension=path.splitext(message.save_as)[1], + mime_type=message.meta.get('mime_type', 'application/octet-stream'), )) - elif response.type == ToolInvokeMessage.MessageType.LINK: - pass # TODO: - - return result - - def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str: - """ - Extract tool response text - """ - result: list[str] = [] - for message in tool_response: - if message.type == ToolInvokeMessage.MessageType.TEXT: + elif message.type == ToolInvokeMessage.MessageType.TEXT: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - result.append(message.message.text) + text += message.message.text + '\n' + yield RunStreamChunkEvent( + chunk_content=message.message.text, + from_variable_selector=[self.node_id, 'text'] + ) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message, ToolInvokeMessage.JsonMessage) + json.append(message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - result.append(f'Link: {message.message.text}') + stream_text = f'Link: {message.message.text}\n' + text += stream_text + yield RunStreamChunkEvent( + chunk_content=stream_text, + from_variable_selector=[self.node_id, 'text'] + ) - return '\n'.join(result) - - def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]: - result: list[dict] = [] - for message in tool_response: - if message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message, ToolInvokeMessage.JsonMessage) - result.append(message.json_object) - return result + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + 'text': text, + 'files': files, + 'json': json + }, + metadata={ + NodeRunMetadataKey.TOOL_INFO: tool_info + }, + inputs=parameters_for_log + ) + ) @classmethod def _extract_variable_selector_to_variable_mapping(