refactor: stream output

This commit is contained in:
Yeuoly 2024-09-10 17:16:55 +08:00
parent b0d53c0ac4
commit cf73374c1b
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
1 changed files with 105 additions and 90 deletions

View File

@ -1,6 +1,7 @@
from collections.abc import Generator, Iterable, Sequence
from collections.abc import Generator, Iterable, Mapping, Sequence
from os import path
from typing import Any, Mapping, cast
from typing import Any, cast
from urllib import response
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -13,6 +14,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunStreamChunkEvent
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models import WorkflowNodeExecutionStatus
@ -26,7 +28,7 @@ class ToolNode(BaseNode):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
def _run(self) -> NodeRunResult:
def _run(self) -> Generator[RunEvent]:
"""
Run the tool node
"""
@ -45,22 +47,34 @@ class ToolNode(BaseNode):
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to get tool runtime: {str(e)}'
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to get tool runtime: {str(e)}'
)
)
return
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
for_log=True
)
try:
messages = ToolEngine.workflow_invoke(
message_stream = ToolEngine.workflow_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
@ -69,30 +83,33 @@ class ToolNode(BaseNode):
thread_pool_id=self.thread_pool_id,
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to invoke tool: {str(e)}',
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to invoke tool: {str(e)}',
)
)
return
# convert tool messages
plain_text, files, json = self._convert_tool_messages(messages)
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'text': plain_text,
'files': files,
'json': json
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
inputs=parameters_for_log
)
# return NodeRunResult(
# status=WorkflowNodeExecutionStatus.SUCCEEDED,
# outputs={
# 'text': plain_text,
# 'files': files,
# 'json': json
# },
# metadata={
# NodeRunMetadataKey.TOOL_INFO: tool_info
# },
# inputs=parameters_for_log
# )
def _generate_parameters(
self,
@ -148,48 +165,40 @@ class ToolNode(BaseNode):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: Generator[ToolInvokeMessage, None, None]):
def _transform_message(self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any]) -> Generator[RunEvent, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
result = list(messages)
files: list[FileVar] = []
text = ""
json: list[dict] = []
# extract plain text and files
files = self._extract_tool_response_binary(result)
plain_text = self._extract_tool_response_text(result)
json = self._extract_tool_response_json(result)
for message in message_stream:
if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
message.type == ToolInvokeMessage.MessageType.IMAGE:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
assert response.meta
url = response.message.text
url = message.message.text
ext = path.splitext(url)[1]
mimetype = response.meta.get('mime_type', 'image/jpeg')
filename = response.save_as or url.split('/')[-1]
transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
mimetype = message.meta.get('mime_type', 'image/jpeg')
filename = message.save_as or url.split('/')[-1]
transfer_method = message.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
# get tool file id
tool_file_id = url.split('/')[-1].split('.')[0]
result.append(FileVar(
files.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
@ -199,48 +208,54 @@ class ToolNode(BaseNode):
extension=ext,
mime_type=mimetype,
))
elif response.type == ToolInvokeMessage.MessageType.BLOB:
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
assert response.meta
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = response.message.text.split('/')[-1].split('.')[0]
result.append(FileVar(
tool_file_id = message.message.text.split('/')[-1].split('.')[0]
files.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=response.save_as,
extension=path.splitext(response.save_as)[1],
mime_type=response.meta.get('mime_type', 'application/octet-stream'),
filename=message.save_as,
extension=path.splitext(message.save_as)[1],
mime_type=message.meta.get('mime_type', 'application/octet-stream'),
))
elif response.type == ToolInvokeMessage.MessageType.LINK:
pass # TODO:
return result
def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
"""
Extract tool response text
"""
result: list[str] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.TEXT:
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
result.append(message.message.text)
text += message.message.text + '\n'
yield RunStreamChunkEvent(
chunk_content=message.message.text,
from_variable_selector=[self.node_id, 'text']
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message, ToolInvokeMessage.JsonMessage)
json.append(message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
result.append(f'Link: {message.message.text}')
stream_text = f'Link: {message.message.text}\n'
text += stream_text
yield RunStreamChunkEvent(
chunk_content=stream_text,
from_variable_selector=[self.node_id, 'text']
)
return '\n'.join(result)
def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]:
result: list[dict] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message, ToolInvokeMessage.JsonMessage)
result.append(message.json_object)
return result
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'text': text,
'files': files,
'json': json
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
inputs=parameters_for_log
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(