mirror of https://github.com/langgenius/dify.git
refactor: stream output
This commit is contained in:
parent
b0d53c0ac4
commit
cf73374c1b
|
|
@ -1,6 +1,7 @@
|
|||
from collections.abc import Generator, Iterable, Sequence
|
||||
from collections.abc import Generator, Iterable, Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Any, Mapping, cast
|
||||
from typing import Any, cast
|
||||
from urllib import response
|
||||
|
||||
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
|
|
@ -13,6 +14,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models import WorkflowNodeExecutionStatus
|
||||
|
|
@ -26,7 +28,7 @@ class ToolNode(BaseNode):
|
|||
_node_data_cls = ToolNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
def _run(self) -> Generator[RunEvent]:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
|
|
@ -45,22 +47,34 @@ class ToolNode(BaseNode):
|
|||
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to get tool runtime: {str(e)}'
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to get tool runtime: {str(e)}'
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
|
||||
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
for_log=True
|
||||
)
|
||||
|
||||
try:
|
||||
messages = ToolEngine.workflow_invoke(
|
||||
message_stream = ToolEngine.workflow_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
|
|
@ -69,30 +83,33 @@ class ToolNode(BaseNode):
|
|||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to invoke tool: {str(e)}',
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to invoke tool: {str(e)}',
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# convert tool messages
|
||||
plain_text, files, json = self._convert_tool_messages(messages)
|
||||
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'text': plain_text,
|
||||
'files': files,
|
||||
'json': json
|
||||
},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
inputs=parameters_for_log
|
||||
)
|
||||
# return NodeRunResult(
|
||||
# status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
# outputs={
|
||||
# 'text': plain_text,
|
||||
# 'files': files,
|
||||
# 'json': json
|
||||
# },
|
||||
# metadata={
|
||||
# NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
# },
|
||||
# inputs=parameters_for_log
|
||||
# )
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
|
|
@ -148,48 +165,40 @@ class ToolNode(BaseNode):
|
|||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _convert_tool_messages(self, messages: Generator[ToolInvokeMessage, None, None]):
|
||||
def _transform_message(self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any]) -> Generator[RunEvent, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
result = list(messages)
|
||||
files: list[FileVar] = []
|
||||
text = ""
|
||||
json: list[dict] = []
|
||||
|
||||
# extract plain text and files
|
||||
files = self._extract_tool_response_binary(result)
|
||||
plain_text = self._extract_tool_response_text(result)
|
||||
json = self._extract_tool_response_json(result)
|
||||
for message in message_stream:
|
||||
if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
return plain_text, files, json
|
||||
|
||||
def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
assert response.meta
|
||||
|
||||
url = response.message.text
|
||||
url = message.message.text
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get('mime_type', 'image/jpeg')
|
||||
filename = response.save_as or url.split('/')[-1]
|
||||
transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
|
||||
mimetype = message.meta.get('mime_type', 'image/jpeg')
|
||||
filename = message.save_as or url.split('/')[-1]
|
||||
transfer_method = message.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = url.split('/')[-1].split('.')[0]
|
||||
result.append(FileVar(
|
||||
files.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
|
|
@ -199,48 +208,54 @@ class ToolNode(BaseNode):
|
|||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
assert response.meta
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = response.message.text.split('/')[-1].split('.')[0]
|
||||
result.append(FileVar(
|
||||
tool_file_id = message.message.text.split('/')[-1].split('.')[0]
|
||||
files.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=response.save_as,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=response.meta.get('mime_type', 'application/octet-stream'),
|
||||
filename=message.save_as,
|
||||
extension=path.splitext(message.save_as)[1],
|
||||
mime_type=message.meta.get('mime_type', 'application/octet-stream'),
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
pass # TODO:
|
||||
|
||||
return result
|
||||
|
||||
def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Extract tool response text
|
||||
"""
|
||||
result: list[str] = []
|
||||
for message in tool_response:
|
||||
if message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
result.append(message.message.text)
|
||||
text += message.message.text + '\n'
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=message.message.text,
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message, ToolInvokeMessage.JsonMessage)
|
||||
json.append(message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
result.append(f'Link: {message.message.text}')
|
||||
stream_text = f'Link: {message.message.text}\n'
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=stream_text,
|
||||
from_variable_selector=[self.node_id, 'text']
|
||||
)
|
||||
|
||||
return '\n'.join(result)
|
||||
|
||||
def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]:
|
||||
result: list[dict] = []
|
||||
for message in tool_response:
|
||||
if message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message, ToolInvokeMessage.JsonMessage)
|
||||
result.append(message.json_object)
|
||||
return result
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'text': text,
|
||||
'files': files,
|
||||
'json': json
|
||||
},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
inputs=parameters_for_log
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
|
|
|
|||
Loading…
Reference in New Issue