mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 11:06:46 +08:00
feat: tool node
This commit is contained in:
parent
dcf9d85e8d
commit
8e491ace5c
@ -2,7 +2,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from mimetypes import guess_extension
|
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||||
@ -39,7 +38,6 @@ from core.tools.entities.tool_entities import (
|
|||||||
)
|
)
|
||||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import Message, MessageAgentThought, MessageFile
|
from models.model import Message, MessageAgentThought, MessageFile
|
||||||
@ -462,73 +460,6 @@ class BaseAgentRunner(AppRunner):
|
|||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
|
|
||||||
"""
|
|
||||||
Transform tool message into agent thought
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
if message.type == ToolInvokeMessage.MessageType.TEXT:
|
|
||||||
result.append(message)
|
|
||||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
|
||||||
result.append(message)
|
|
||||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
|
||||||
# try to download image
|
|
||||||
try:
|
|
||||||
file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
|
|
||||||
conversation_id=self.message.conversation_id,
|
|
||||||
file_url=message.message)
|
|
||||||
|
|
||||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
|
||||||
|
|
||||||
result.append(ToolInvokeMessage(
|
|
||||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
|
||||||
message=url,
|
|
||||||
save_as=message.save_as,
|
|
||||||
meta=message.meta.copy() if message.meta is not None else {},
|
|
||||||
))
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(e)
|
|
||||||
result.append(ToolInvokeMessage(
|
|
||||||
type=ToolInvokeMessage.MessageType.TEXT,
|
|
||||||
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
|
|
||||||
meta=message.meta.copy() if message.meta is not None else {},
|
|
||||||
save_as=message.save_as,
|
|
||||||
))
|
|
||||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
|
||||||
# get mime type and save blob to storage
|
|
||||||
mimetype = message.meta.get('mime_type', 'octet/stream')
|
|
||||||
# if message is str, encode it to bytes
|
|
||||||
if isinstance(message.message, str):
|
|
||||||
message.message = message.message.encode('utf-8')
|
|
||||||
file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
|
|
||||||
conversation_id=self.message.conversation_id,
|
|
||||||
file_binary=message.message,
|
|
||||||
mimetype=mimetype)
|
|
||||||
|
|
||||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
|
|
||||||
|
|
||||||
# check if file is image
|
|
||||||
if 'image' in mimetype:
|
|
||||||
result.append(ToolInvokeMessage(
|
|
||||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
|
||||||
message=url,
|
|
||||||
save_as=message.save_as,
|
|
||||||
meta=message.meta.copy() if message.meta is not None else {},
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
result.append(ToolInvokeMessage(
|
|
||||||
type=ToolInvokeMessage.MessageType.LINK,
|
|
||||||
message=url,
|
|
||||||
save_as=message.save_as,
|
|
||||||
meta=message.meta.copy() if message.meta is not None else {},
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
result.append(message)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from core.tools.errors import (
|
|||||||
ToolProviderCredentialValidationError,
|
ToolProviderCredentialValidationError,
|
||||||
ToolProviderNotFoundError,
|
ToolProviderNotFoundError,
|
||||||
)
|
)
|
||||||
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from models.model import Conversation, Message
|
from models.model import Conversation, Message
|
||||||
|
|
||||||
|
|
||||||
@ -280,7 +281,12 @@ class CotAgentRunner(BaseAgentRunner):
|
|||||||
tool_parameters=tool_call_args
|
tool_parameters=tool_call_args
|
||||||
)
|
)
|
||||||
# transform tool response to llm friendly response
|
# transform tool response to llm friendly response
|
||||||
tool_response = self.transform_tool_invoke_messages(tool_response)
|
tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||||
|
messages=tool_response,
|
||||||
|
user_id=self.user_id,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
conversation_id=self.message.conversation_id
|
||||||
|
)
|
||||||
# extract binary data from tool invoke message
|
# extract binary data from tool invoke message
|
||||||
binary_files = self.extract_tool_response_binary(tool_response)
|
binary_files = self.extract_tool_response_binary(tool_response)
|
||||||
# create message file
|
# create message file
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from core.tools.errors import (
|
|||||||
ToolProviderCredentialValidationError,
|
ToolProviderCredentialValidationError,
|
||||||
ToolProviderNotFoundError,
|
ToolProviderNotFoundError,
|
||||||
)
|
)
|
||||||
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from models.model import Conversation, Message, MessageAgentThought
|
from models.model import Conversation, Message, MessageAgentThought
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -270,7 +271,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
tool_parameters=tool_call_args,
|
tool_parameters=tool_call_args,
|
||||||
)
|
)
|
||||||
# transform tool invoke message to get LLM friendly message
|
# transform tool invoke message to get LLM friendly message
|
||||||
tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
|
tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||||
|
messages=tool_invoke_message,
|
||||||
|
user_id=self.user_id,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
conversation_id=self.message.conversation_id
|
||||||
|
)
|
||||||
# extract binary data from tool invoke message
|
# extract binary data from tool invoke message
|
||||||
binary_files = self.extract_tool_response_binary(tool_invoke_message)
|
binary_files = self.extract_tool_response_binary(tool_invoke_message)
|
||||||
# create message file
|
# create message file
|
||||||
|
|||||||
@ -34,6 +34,7 @@ from core.tools.utils.configuration import (
|
|||||||
ToolParameterConfigurationManager,
|
ToolParameterConfigurationManager,
|
||||||
)
|
)
|
||||||
from core.tools.utils.encoder import serialize_base_model_dict
|
from core.tools.utils.encoder import serialize_base_model_dict
|
||||||
|
from core.workflow.nodes.tool.entities import ToolEntity
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||||
|
|
||||||
@ -225,6 +226,48 @@ class ToolManager:
|
|||||||
else:
|
else:
|
||||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _init_runtime_parameter(parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
|
||||||
|
"""
|
||||||
|
init runtime parameter
|
||||||
|
"""
|
||||||
|
parameter_value = parameters.get(parameter_rule.name)
|
||||||
|
if not parameter_value:
|
||||||
|
# get default value
|
||||||
|
parameter_value = parameter_rule.default
|
||||||
|
if not parameter_value and parameter_rule.required:
|
||||||
|
raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
|
||||||
|
|
||||||
|
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
|
||||||
|
# check if tool_parameter_config in options
|
||||||
|
options = list(map(lambda x: x.value, parameter_rule.options))
|
||||||
|
if parameter_value not in options:
|
||||||
|
raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
|
||||||
|
|
||||||
|
# convert tool parameter config to correct type
|
||||||
|
try:
|
||||||
|
if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
|
||||||
|
# check if tool parameter is integer
|
||||||
|
if isinstance(parameter_value, int):
|
||||||
|
parameter_value = parameter_value
|
||||||
|
elif isinstance(parameter_value, float):
|
||||||
|
parameter_value = parameter_value
|
||||||
|
elif isinstance(parameter_value, str):
|
||||||
|
if '.' in parameter_value:
|
||||||
|
parameter_value = float(parameter_value)
|
||||||
|
else:
|
||||||
|
parameter_value = int(parameter_value)
|
||||||
|
elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||||
|
parameter_value = bool(parameter_value)
|
||||||
|
elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
|
||||||
|
parameter_value = str(parameter_value)
|
||||||
|
elif parameter_rule.type == ToolParameter.ToolParameterType:
|
||||||
|
parameter_value = str(parameter_value)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
|
||||||
|
|
||||||
|
return parameter_value
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
|
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
|
||||||
"""
|
"""
|
||||||
@ -239,44 +282,9 @@ class ToolManager:
|
|||||||
parameters = tool_entity.get_all_runtime_parameters()
|
parameters = tool_entity.get_all_runtime_parameters()
|
||||||
for parameter in parameters:
|
for parameter in parameters:
|
||||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||||
# get tool parameter from form
|
|
||||||
tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
|
|
||||||
if not tool_parameter_config:
|
|
||||||
# get default value
|
|
||||||
tool_parameter_config = parameter.default
|
|
||||||
if not tool_parameter_config and parameter.required:
|
|
||||||
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
|
|
||||||
|
|
||||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
|
||||||
# check if tool_parameter_config in options
|
|
||||||
options = list(map(lambda x: x.value, parameter.options))
|
|
||||||
if tool_parameter_config not in options:
|
|
||||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
|
|
||||||
|
|
||||||
# convert tool parameter config to correct type
|
|
||||||
try:
|
|
||||||
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
|
|
||||||
# check if tool parameter is integer
|
|
||||||
if isinstance(tool_parameter_config, int):
|
|
||||||
tool_parameter_config = tool_parameter_config
|
|
||||||
elif isinstance(tool_parameter_config, float):
|
|
||||||
tool_parameter_config = tool_parameter_config
|
|
||||||
elif isinstance(tool_parameter_config, str):
|
|
||||||
if '.' in tool_parameter_config:
|
|
||||||
tool_parameter_config = float(tool_parameter_config)
|
|
||||||
else:
|
|
||||||
tool_parameter_config = int(tool_parameter_config)
|
|
||||||
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
|
|
||||||
tool_parameter_config = bool(tool_parameter_config)
|
|
||||||
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
|
|
||||||
tool_parameter_config = str(tool_parameter_config)
|
|
||||||
elif parameter.type == ToolParameter.ToolParameterType:
|
|
||||||
tool_parameter_config = str(tool_parameter_config)
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
|
|
||||||
|
|
||||||
# save tool parameter to tool entity memory
|
# save tool parameter to tool entity memory
|
||||||
runtime_parameters[parameter.name] = tool_parameter_config
|
value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters)
|
||||||
|
runtime_parameters[parameter.name] = value
|
||||||
|
|
||||||
# decrypt runtime parameters
|
# decrypt runtime parameters
|
||||||
encryption_manager = ToolParameterConfigurationManager(
|
encryption_manager = ToolParameterConfigurationManager(
|
||||||
@ -289,6 +297,38 @@ class ToolManager:
|
|||||||
|
|
||||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||||
return tool_entity
|
return tool_entity
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler):
|
||||||
|
"""
|
||||||
|
get the workflow tool runtime
|
||||||
|
"""
|
||||||
|
tool_entity = ToolManager.get_tool_runtime(
|
||||||
|
provider_type=workflow_tool.provider_type,
|
||||||
|
provider_name=workflow_tool.provider_id,
|
||||||
|
tool_name=workflow_tool.tool_name,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
agent_callback=agent_callback
|
||||||
|
)
|
||||||
|
runtime_parameters = {}
|
||||||
|
parameters = tool_entity.get_all_runtime_parameters()
|
||||||
|
|
||||||
|
for parameter in parameters:
|
||||||
|
# save tool parameter to tool entity memory
|
||||||
|
value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters)
|
||||||
|
runtime_parameters[parameter.name] = value
|
||||||
|
|
||||||
|
# decrypt runtime parameters
|
||||||
|
encryption_manager = ToolParameterConfigurationManager(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
tool_runtime=tool_entity,
|
||||||
|
provider_name=workflow_tool.provider_id,
|
||||||
|
provider_type=workflow_tool.provider_type,
|
||||||
|
)
|
||||||
|
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||||
|
|
||||||
|
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||||
|
return tool_entity
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
|
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
|
||||||
|
|||||||
85
api/core/tools/utils/message_transformer.py
Normal file
85
api/core/tools/utils/message_transformer.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
from mimetypes import guess_extension
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ToolFileMessageTransformer:
|
||||||
|
@staticmethod
|
||||||
|
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
conversation_id: str) -> list[ToolInvokeMessage]:
|
||||||
|
"""
|
||||||
|
Transform tool message and handle file download
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
if message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||||
|
result.append(message)
|
||||||
|
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||||
|
result.append(message)
|
||||||
|
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||||
|
# try to download image
|
||||||
|
try:
|
||||||
|
file = ToolFileManager.create_file_by_url(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
file_url=message.message
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||||
|
|
||||||
|
result.append(ToolInvokeMessage(
|
||||||
|
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||||
|
message=url,
|
||||||
|
save_as=message.save_as,
|
||||||
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
result.append(ToolInvokeMessage(
|
||||||
|
type=ToolInvokeMessage.MessageType.TEXT,
|
||||||
|
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
|
||||||
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
|
save_as=message.save_as,
|
||||||
|
))
|
||||||
|
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||||
|
# get mime type and save blob to storage
|
||||||
|
mimetype = message.meta.get('mime_type', 'octet/stream')
|
||||||
|
# if message is str, encode it to bytes
|
||||||
|
if isinstance(message.message, str):
|
||||||
|
message.message = message.message.encode('utf-8')
|
||||||
|
|
||||||
|
file = ToolFileManager.create_file_by_raw(
|
||||||
|
user_id=user_id, tenant_id=tenant_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
file_binary=message.message,
|
||||||
|
mimetype=mimetype
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
|
||||||
|
|
||||||
|
# check if file is image
|
||||||
|
if 'image' in mimetype:
|
||||||
|
result.append(ToolInvokeMessage(
|
||||||
|
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||||
|
message=url,
|
||||||
|
save_as=message.save_as,
|
||||||
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
result.append(ToolInvokeMessage(
|
||||||
|
type=ToolInvokeMessage.MessageType.LINK,
|
||||||
|
message=url,
|
||||||
|
save_as=message.save_as,
|
||||||
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
result.append(message)
|
||||||
|
|
||||||
|
return result
|
||||||
23
api/core/workflow/nodes/tool/entities.py
Normal file
23
api/core/workflow/nodes/tool/entities.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import Literal, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
|
|
||||||
|
ToolParameterValue = Union[str, int, float, bool]
|
||||||
|
|
||||||
|
class ToolEntity(BaseModel):
|
||||||
|
provider_id: str
|
||||||
|
provider_type: Literal['builtin', 'api']
|
||||||
|
provider_name: str # redundancy
|
||||||
|
tool_name: str
|
||||||
|
tool_label: str # redundancy
|
||||||
|
tool_parameters: dict[str, ToolParameterValue]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||||
|
"""
|
||||||
|
Tool Node Schema
|
||||||
|
"""
|
||||||
|
tool_inputs: list[VariableSelector]
|
||||||
@ -1,5 +1,139 @@
|
|||||||
|
from os import path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from core.file.file_obj import FileTransferMethod
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool_manager import ToolManager
|
||||||
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
pass
|
"""
|
||||||
|
Tool Node
|
||||||
|
"""
|
||||||
|
_node_data_cls = ToolNodeData
|
||||||
|
_node_type = NodeType.TOOL
|
||||||
|
|
||||||
|
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||||
|
"""
|
||||||
|
Run the tool node
|
||||||
|
"""
|
||||||
|
|
||||||
|
node_data = cast(ToolNodeData, self.node_data)
|
||||||
|
|
||||||
|
# extract tool parameters
|
||||||
|
parameters = {
|
||||||
|
k.variable: variable_pool.get_variable_value(k.value_selector)
|
||||||
|
for k in node_data.tool_inputs
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parameters) != len(node_data.tool_inputs):
|
||||||
|
raise ValueError('Invalid tool parameters')
|
||||||
|
|
||||||
|
# get tool runtime
|
||||||
|
try:
|
||||||
|
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None)
|
||||||
|
except Exception as e:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=parameters,
|
||||||
|
error=f'Failed to get tool runtime: {str(e)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
messages = tool_runtime.invoke(None, parameters)
|
||||||
|
except Exception as e:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=parameters,
|
||||||
|
error=f'Failed to invoke tool: {str(e)}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert tool messages
|
||||||
|
plain_text, files = self._convert_tool_messages(messages)
|
||||||
|
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCESS,
|
||||||
|
outputs={
|
||||||
|
'text': plain_text,
|
||||||
|
'files': files
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
|
||||||
|
"""
|
||||||
|
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||||
|
"""
|
||||||
|
# transform message and handle file storage
|
||||||
|
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(messages)
|
||||||
|
# extract plain text and files
|
||||||
|
files = self._extract_tool_response_binary(messages)
|
||||||
|
plain_text = self._extract_tool_response_text(messages)
|
||||||
|
|
||||||
|
return plain_text, files
|
||||||
|
|
||||||
|
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Extract tool response binary
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for response in tool_response:
|
||||||
|
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||||
|
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||||
|
url = response.message
|
||||||
|
ext = path.splitext(url)[1]
|
||||||
|
mimetype = response.meta.get('mime_type', 'image/jpeg')
|
||||||
|
filename = response.save_as or url.split('/')[-1]
|
||||||
|
result.append({
|
||||||
|
'type': 'image',
|
||||||
|
'transfer_method': FileTransferMethod.TOOL_FILE,
|
||||||
|
'url': url,
|
||||||
|
'upload_file_id': None,
|
||||||
|
'filename': filename,
|
||||||
|
'file-ext': ext,
|
||||||
|
'mime-type': mimetype,
|
||||||
|
})
|
||||||
|
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||||
|
result.append({
|
||||||
|
'type': 'image', # TODO: only support image for now
|
||||||
|
'transfer_method': FileTransferMethod.TOOL_FILE,
|
||||||
|
'url': response.message,
|
||||||
|
'upload_file_id': None,
|
||||||
|
'filename': response.save_as,
|
||||||
|
'file-ext': path.splitext(response.save_as)[1],
|
||||||
|
'mime-type': response.meta.get('mime_type', 'application/octet-stream'),
|
||||||
|
})
|
||||||
|
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||||
|
pass # TODO:
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||||
|
"""
|
||||||
|
Extract tool response text
|
||||||
|
"""
|
||||||
|
return ''.join([
|
||||||
|
f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
|
||||||
|
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||||
|
for message in tool_response
|
||||||
|
])
|
||||||
|
|
||||||
|
def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict:
|
||||||
|
"""
|
||||||
|
Convert ToolInvokeMessage into file
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
|
||||||
|
"""
|
||||||
|
Extract variable selector to variable mapping
|
||||||
|
"""
|
||||||
|
pass
|
||||||
Loading…
Reference in New Issue
Block a user