From 5b81234db8af0540645b7f14760185b9f7c5679a Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 1 Apr 2024 16:43:10 +0800 Subject: [PATCH] fix: tool entities --- api/core/workflow/nodes/tool/entities.py | 32 ++++++------- api/core/workflow/nodes/tool/tool_node.py | 56 ++++++++++++----------- 2 files changed, 45 insertions(+), 43 deletions(-) diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 42c511302d..ebaa7a56bd 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union +from typing import Literal, Union from pydantic import BaseModel, validator @@ -16,24 +16,22 @@ class ToolEntity(BaseModel): class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): - value_type: Literal['variable', 'static'] - static_value: Optional[Union[int, float, str]] - variable_value: Optional[Union[str, list[str]]] - parameter_name: str + type: Literal['mixed', 'variable', 'constant'] + value: Union[ToolParameterValue, list[str]] - @validator('value_type', pre=True, always=True) - def check_value_type(cls, value, values): - if value == 'variable': - # check if template_value is None - if values.get('variable_value') is not None: - raise ValueError('template_value must be None for value_type variable') - elif value == 'static': - # check if static_value is None - if values.get('static_value') is None: - raise ValueError('static_value must be provided for value_type static') + @validator('type', pre=True, always=True) + def check_type(cls, value, values): + typ = value + value = values.get('value') + if typ == 'mixed' and not isinstance(value, str): + raise ValueError('value must be a string') + elif typ == 'variable' and not isinstance(value, list): + raise ValueError('value must be a list') + elif typ == 'constant' and not isinstance(value, ToolParameterValue): + raise ValueError('value must be a string, int, float, or bool') return value - + """ Tool Node Schema """ - tool_parameters: list[ToolInput] + tool_parameters: dict[str, ToolInput] diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 83e51a04c0..8a67284971 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -88,24 +88,27 @@ class ToolNode(BaseNode): Generate parameters """ result = {} - for parameter in node_data.tool_parameters: - if parameter.value_type == 'static': - result[parameter.parameter_name] = parameter.static_value - else: - if isinstance(parameter.variable_value, str): - parser = VariableTemplateParser(parameter.variable_value) - variable_selectors = parser.extract_variable_selectors() - values = { - selector.variable: variable_pool.get_variable_value(selector) - for selector in variable_selectors - } - - # if multiple values, use the parser to format the values into a string - result[parameter.parameter_name] = parser.format(values) - elif isinstance(parameter.variable_value, list): - result[parameter.parameter_name] = variable_pool.get_variable_value(parameter.variable_value) + for parameter_name in node_data.tool_parameters: + input = node_data.tool_parameters[parameter_name] + if input.type == 'mixed': + result[parameter_name] = self._format_variable_template(input.value, variable_pool) + elif input.type == 'variable': + result[parameter_name] = variable_pool.get_variable_value(input.value) + elif input.type == 'constant': + result[parameter_name] = input.value return result + + def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str: + """ + Format variable template + """ + inputs = {} + template_parser = VariableTemplateParser(template) + for selector in template_parser.extract_variable_selectors(): + inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector) + + return template_parser.format(inputs) def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]: """ @@ -184,14 +187,15 @@ class ToolNode(BaseNode): :return: """ result = {} - for parameter in node_data.tool_parameters: - if parameter.value_type == 'variable': - if isinstance(parameter.variable_value, str): - parser = VariableTemplateParser(parameter.variable_value) - variable_selectors = parser.extract_variable_selectors() - for selector in variable_selectors: - result[selector.variable] = selector.value_selector - elif isinstance(parameter.variable_value, list): - result[parameter.parameter_name] = parameter.variable_value + for parameter_name in node_data.tool_parameters: + input = node_data.tool_parameters[parameter_name] + if input.type == 'mixed': + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + elif input.type == 'variable': + result[parameter_name] = input.value + elif input.type == 'constant': + pass - return result \ No newline at end of file + return result