diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 0774fc4f3d..47498f4f5f 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -128,8 +128,10 @@ class ToolNode(BaseNode): else: tool_input = node_data.tool_parameters[parameter_name] if tool_input.type == 'variable': - # TODO: check if the variable exists in the variable pool - parameter_value = variable_pool.get(tool_input.value).value + parameter_value_segment = variable_pool.get(tool_input.value) + if not parameter_value_segment: + raise Exception("input variable dose not exists") + parameter_value = parameter_value_segment.value else: segment_group = parser.convert_template( template=str(tool_input.value), @@ -163,7 +165,7 @@ class ToolNode(BaseNode): return plain_text, files, json - def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]: + def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]: """ Extract tool response binary """ @@ -172,7 +174,10 @@ class ToolNode(BaseNode): for response in tool_response: if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ response.type == ToolInvokeMessage.MessageType.IMAGE: - url = response.message + assert isinstance(response.message, ToolInvokeMessage.TextMessage) + assert response.meta + + url = response.message.text ext = path.splitext(url)[1] mimetype = response.meta.get('mime_type', 'image/jpeg') filename = response.save_as or url.split('/')[-1] @@ -192,7 +197,10 @@ class ToolNode(BaseNode): )) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id - tool_file_id = response.message.split('/')[-1].split('.')[0] + assert isinstance(response.message, ToolInvokeMessage.TextMessage) + assert response.meta + + tool_file_id = response.message.text.split('/')[-1].split('.')[0] result.append(FileVar( tenant_id=self.tenant_id, type=FileType.IMAGE, @@ -207,18 +215,28 @@ class ToolNode(BaseNode): return result - def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str: + def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str: """ Extract tool response text """ - return '\n'.join([ - f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else - f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else '' - for message in tool_response - ]) + result: list[str] = [] + for message in tool_response: + if message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + result.append(message.message.text) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + result.append(f'Link: {message.message.text}') - def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: - return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] + return '\n'.join(result) + + def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]: + result: list[dict] = [] + for message in tool_response: + if message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message, ToolInvokeMessage.JsonMessage) + result.append(message.json_object) + return result @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: @@ -231,6 +249,7 @@ class ToolNode(BaseNode): for parameter_name in node_data.tool_parameters: input = node_data.tool_parameters[parameter_name] if input.type == 'mixed': + assert isinstance(input.value, str) selectors = VariableTemplateParser(input.value).extract_variable_selectors() for selector in selectors: result[selector.variable] = selector.value_selector