From 94f3cf1a4c7ab5955c2bdb348fdb91f6f44779da Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 11 Mar 2024 16:13:52 +0800 Subject: [PATCH] feat: tool entity --- api/core/tools/tool_manager.py | 2 +- api/core/workflow/nodes/tool/entities.py | 19 +++++++++++---- api/core/workflow/nodes/tool/tool_node.py | 29 ++++++++++++----------- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ea66362195..52e1e71d82 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -315,7 +315,7 @@ class ToolManager: for parameter in parameters: # save tool parameter to tool entity memory - value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_parameters) + value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations) runtime_parameters[parameter.name] = value # decrypt runtime parameters diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index e782bd3004..0b3bf76aac 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,6 +1,6 @@ -from typing import Literal, Union +from typing import Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector @@ -13,11 +13,20 @@ class ToolEntity(BaseModel): provider_name: str # redundancy tool_name: str tool_label: str # redundancy - tool_parameters: dict[str, ToolParameterValue] - + tool_configurations: dict[str, ToolParameterValue] class ToolNodeData(BaseNodeData, ToolEntity): + class ToolInput(VariableSelector): + variable_type: Literal['selector', 'static'] + value: Optional[str] + + @validator('value') + def check_value(cls, value, values, **kwargs): + if values['variable_type'] == 'static' and value is None: + raise ValueError('value is required for static variable') + return value + """ Tool Node Schema """ - tool_inputs: list[VariableSelector] + tool_parameters: list[ToolInput] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index a0b0991eb6..f1897780f2 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -27,14 +27,8 @@ class ToolNode(BaseNode): node_data = cast(ToolNodeData, self.node_data) - # extract tool parameters - parameters = { - k.variable: variable_pool.get_variable_value(k.value_selector) - for k in node_data.tool_inputs - } - - if len(parameters) != len(node_data.tool_inputs): - raise ValueError('Invalid tool parameters') + # get parameters + parameters = self._generate_parameters(variable_pool, node_data) # get tool runtime try: @@ -47,6 +41,7 @@ class ToolNode(BaseNode): ) try: + # TODO: user_id messages = tool_runtime.invoke(None, parameters) except Exception as e: return NodeRunResult( @@ -59,12 +54,23 @@ class ToolNode(BaseNode): plain_text, files = self._convert_tool_messages(messages) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCESS, + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ 'text': plain_text, 'files': files }, ) + + def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict: + """ + Generate parameters + """ + return { + k.variable: + k.value if k.variable_type == 'static' else + variable_pool.get_variable_value(k.value) if k.variable_type == 'selector' else '' + for k in node_data.tool_parameters + } def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]: """ @@ -125,11 +131,6 @@ class ToolNode(BaseNode): for message in tool_response ]) - def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict: - """ - Convert ToolInvokeMessage into file - """ - pass @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]: