This commit is contained in:
Yeuoly 2024-03-29 20:53:48 +08:00
parent fb364d44d1
commit 142d1be4f8
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
3 changed files with 27 additions and 16 deletions

View File

@ -1,5 +1,7 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
ToolParameterValue = Union[str, int, float, bool]
@ -16,14 +18,14 @@ class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
value_type: Literal['variable', 'static']
static_value: Optional[Union[int, float, str]]
template_value: Optional[str]
variable_value: Optional[Union[str, list[str]]]
parameter_name: str
@validator('value_type', pre=True, always=True)
def check_value_type(cls, value, values):
if value == 'variable':
# check if template_value is None
if values.get('template_value') is not None:
if values.get('variable_value') is not None:
raise ValueError('template_value must be None for value_type variable')
elif value == 'static':
# check if static_value is None

View File

@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
@ -77,19 +78,18 @@ class ToolNode(BaseNode):
if parameter.value_type == 'static':
result[parameter.parameter_name] = parameter.static_value
else:
parser = VariableTemplateParser(parameter.template_value)
variable_selectors = parser.extract_variable_selectors()
values = {
selector.variable: variable_pool.get_variable_value(selector)
for selector in variable_selectors
}
if isinstance(parameter.variable_value, str):
parser = VariableTemplateParser(parameter.variable_value)
variable_selectors = parser.extract_variable_selectors()
values = {
selector.variable: variable_pool.get_variable_value(selector)
for selector in variable_selectors
}
if len(values) == 1:
# if only one value, use the value directly to avoid type transformation
result[parameter.parameter_name] = list(values.values())[0]
else:
# if multiple values, use the parser to format the values into a string
result[parameter.parameter_name] = parser.format(values)
elif isinstance(parameter.variable_value, list):
result[parameter.parameter_name] = variable_pool.get_variable_value(parameter.variable_value)
return result
@ -161,3 +161,12 @@ class ToolNode(BaseNode):
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
for message in tool_response
])
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {}

View File

@ -27,10 +27,10 @@ def test_tool_invoke():
'tool_configurations': {},
'tool_parameters': [
{
'variable': 'expression',
'value_selector': ['1', '123', 'args1'],
'variable_type': 'selector',
'value': None
'value_type': 'variable',
'static_value': None,
'variable_value': ['1', '123', 'args1'],
'parameter_name': 'expression',
},
]
}