From fb364d44d1701aa9bfce06daba1296053d7acc19 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 29 Mar 2024 20:12:26 +0800 Subject: [PATCH] refactor --- api/core/workflow/nodes/tool/entities.py | 32 +++++++++---------- api/core/workflow/nodes/tool/tool_node.py | 38 +++++++++++++---------- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 7eb3cf655b..13c87dfc20 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,7 +1,5 @@ from typing import Literal, Optional, Union - from pydantic import BaseModel, validator - from core.workflow.entities.base_node_data_entities import BaseNodeData ToolParameterValue = Union[str, int, float, bool] @@ -16,23 +14,23 @@ class ToolEntity(BaseModel): class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): - variable: str - variable_type: Literal['selector', 'static'] - value_selector: Optional[list[str]] - value: Optional[str] + value_type: Literal['variable', 'static'] + static_value: Optional[Union[int, float, str]] + template_value: Optional[str] + parameter_name: str - @validator('value') - def check_value(cls, value, values, **kwargs): - if values['variable_type'] == 'static' and value is None: - raise ValueError('value is required for static variable') + @validator('value_type', pre=True, always=True) + def check_value_type(cls, value, values): + if value == 'variable': + # check if template_value is None + if values.get('template_value') is not None: + raise ValueError('template_value must be None for value_type variable') + elif value == 'static': + # check if static_value is None + if values.get('static_value') is None: + raise ValueError('static_value must be provided for value_type static') return value - - @validator('value_selector') - def check_value_selector(cls, value_selector, values, **kwargs): - if values['variable_type'] == 'selector' and value_selector is None: - raise ValueError('value_selector is required for selector variable') - return value_selector - + """ Tool Node Schema """ diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 003a259243..f2bb24cdef 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -11,6 +11,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.utils.variable_template_parser import VariableTemplateParser from models.workflow import WorkflowNodeExecutionStatus @@ -71,12 +72,26 @@ class ToolNode(BaseNode): """ Generate parameters """ - return { - k.variable: - k.value if k.variable_type == 'static' else - variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else '' - for k in node_data.tool_parameters - } + result = {} + for parameter in node_data.tool_parameters: + if parameter.value_type == 'static': + result[parameter.parameter_name] = parameter.static_value + else: + parser = VariableTemplateParser(parameter.template_value) + variable_selectors = parser.extract_variable_selectors() + values = { + selector.variable: variable_pool.get_variable_value(selector) + for selector in variable_selectors + } + + if len(values) == 1: + # if only one value, use the value directly to avoid type transformation + result[parameter.parameter_name] = list(values.values())[0] + else: + # if multiple values, use the parser to format the values into a string + result[parameter.parameter_name] = parser.format(values) + + return result def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]: """ @@ -146,14 +161,3 @@ class ToolNode(BaseNode): f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else '' for message in tool_response ]) - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: - """ - Extract variable selector to variable mapping - """ - return { - k.variable: k.value_selector - for k in node_data.tool_parameters - if k.variable_type == 'selector' - }