diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 13c87dfc20..42c511302d 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,5 +1,7 @@ from typing import Literal, Optional, Union + from pydantic import BaseModel, validator + from core.workflow.entities.base_node_data_entities import BaseNodeData ToolParameterValue = Union[str, int, float, bool] @@ -16,14 +18,14 @@ class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): value_type: Literal['variable', 'static'] static_value: Optional[Union[int, float, str]] - template_value: Optional[str] + variable_value: Optional[Union[str, list[str]]] parameter_name: str @validator('value_type', pre=True, always=True) def check_value_type(cls, value, values): if value == 'variable': # check if template_value is None - if values.get('template_value') is not None: + if values.get('variable_value') is not None: raise ValueError('template_value must be None for value_type variable') elif value == 'static': # check if static_value is None diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index f2bb24cdef..63be230d6c 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode @@ -77,19 +78,18 @@ class ToolNode(BaseNode): if parameter.value_type == 'static': result[parameter.parameter_name] = parameter.static_value else: - parser = VariableTemplateParser(parameter.template_value) - variable_selectors = parser.extract_variable_selectors() - values = { - selector.variable: variable_pool.get_variable_value(selector) - for selector in variable_selectors - } + if isinstance(parameter.variable_value, str): + parser = VariableTemplateParser(parameter.variable_value) + variable_selectors = parser.extract_variable_selectors() + values = { + selector.variable: variable_pool.get_variable_value(selector) + for selector in variable_selectors + } - if len(values) == 1: - # if only one value, use the value directly to avoid type transformation - result[parameter.parameter_name] = list(values.values())[0] - else: # if multiple values, use the parser to format the values into a string result[parameter.parameter_name] = parser.format(values) + elif isinstance(parameter.variable_value, list): + result[parameter.parameter_name] = variable_pool.get_variable_value(parameter.variable_value) return result @@ -161,3 +161,12 @@ class ToolNode(BaseNode): f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else '' for message in tool_response ]) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} \ No newline at end of file diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 66139563e2..43a0185844 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -27,10 +27,10 @@ def test_tool_invoke(): 'tool_configurations': {}, 'tool_parameters': [ { - 'variable': 'expression', - 'value_selector': ['1', '123', 'args1'], - 'variable_type': 'selector', - 'value': None + 'value_type': 'variable', + 'static_value': None, + 'variable_value': ['1', '123', 'args1'], + 'parameter_name': 'expression', }, ] }