This commit is contained in:
Yeuoly 2024-03-29 20:12:26 +08:00
parent a647698c32
commit fb364d44d1
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
2 changed files with 36 additions and 34 deletions

View File

@ -1,7 +1,5 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
ToolParameterValue = Union[str, int, float, bool]
@ -16,23 +14,23 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
variable: str
variable_type: Literal['selector', 'static']
value_selector: Optional[list[str]]
value: Optional[str]
value_type: Literal['variable', 'static']
static_value: Optional[Union[int, float, str]]
template_value: Optional[str]
parameter_name: str
@validator('value')
def check_value(cls, value, values, **kwargs):
if values['variable_type'] == 'static' and value is None:
raise ValueError('value is required for static variable')
@validator('value_type', pre=True, always=True)
def check_value_type(cls, value, values):
if value == 'variable':
# check if template_value is None
if values.get('template_value') is not None:
raise ValueError('template_value must be None for value_type variable')
elif value == 'static':
# check if static_value is None
if values.get('static_value') is None:
raise ValueError('static_value must be provided for value_type static')
return value
@validator('value_selector')
def check_value_selector(cls, value_selector, values, **kwargs):
if values['variable_type'] == 'selector' and value_selector is None:
raise ValueError('value_selector is required for selector variable')
return value_selector
"""
Tool Node Schema
"""

View File

@ -11,6 +11,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from models.workflow import WorkflowNodeExecutionStatus
@ -71,12 +72,26 @@ class ToolNode(BaseNode):
"""
Generate parameters
"""
return {
k.variable:
k.value if k.variable_type == 'static' else
variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else ''
for k in node_data.tool_parameters
}
result = {}
for parameter in node_data.tool_parameters:
if parameter.value_type == 'static':
result[parameter.parameter_name] = parameter.static_value
else:
parser = VariableTemplateParser(parameter.template_value)
variable_selectors = parser.extract_variable_selectors()
values = {
selector.variable: variable_pool.get_variable_value(selector)
for selector in variable_selectors
}
if len(values) == 1:
# if only one value, use the value directly to avoid type transformation
result[parameter.parameter_name] = list(values.values())[0]
else:
# if multiple values, use the parser to format the values into a string
result[parameter.parameter_name] = parser.format(values)
return result
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
"""
@ -146,14 +161,3 @@ class ToolNode(BaseNode):
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
for message in tool_response
])
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
"""
return {
k.variable: k.value_selector
for k in node_data.tool_parameters
if k.variable_type == 'selector'
}