mirror of https://github.com/langgenius/dify.git
refactor
This commit is contained in:
parent
fb364d44d1
commit
142d1be4f8
|
|
@ -1,5 +1,7 @@
|
|||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
ToolParameterValue = Union[str, int, float, bool]
|
||||
|
|
@ -16,14 +18,14 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
|||
class ToolInput(BaseModel):
|
||||
value_type: Literal['variable', 'static']
|
||||
static_value: Optional[Union[int, float, str]]
|
||||
template_value: Optional[str]
|
||||
variable_value: Optional[Union[str, list[str]]]
|
||||
parameter_name: str
|
||||
|
||||
@validator('value_type', pre=True, always=True)
|
||||
def check_value_type(cls, value, values):
|
||||
if value == 'variable':
|
||||
# check if template_value is None
|
||||
if values.get('template_value') is not None:
|
||||
if values.get('variable_value') is not None:
|
||||
raise ValueError('template_value must be None for value_type variable')
|
||||
elif value == 'static':
|
||||
# check if static_value is None
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
|
|||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
|
|
@ -77,19 +78,18 @@ class ToolNode(BaseNode):
|
|||
if parameter.value_type == 'static':
|
||||
result[parameter.parameter_name] = parameter.static_value
|
||||
else:
|
||||
parser = VariableTemplateParser(parameter.template_value)
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
values = {
|
||||
selector.variable: variable_pool.get_variable_value(selector)
|
||||
for selector in variable_selectors
|
||||
}
|
||||
if isinstance(parameter.variable_value, str):
|
||||
parser = VariableTemplateParser(parameter.variable_value)
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
values = {
|
||||
selector.variable: variable_pool.get_variable_value(selector)
|
||||
for selector in variable_selectors
|
||||
}
|
||||
|
||||
if len(values) == 1:
|
||||
# if only one value, use the value directly to avoid type transformation
|
||||
result[parameter.parameter_name] = list(values.values())[0]
|
||||
else:
|
||||
# if multiple values, use the parser to format the values into a string
|
||||
result[parameter.parameter_name] = parser.format(values)
|
||||
elif isinstance(parameter.variable_value, list):
|
||||
result[parameter.parameter_name] = variable_pool.get_variable_value(parameter.variable_value)
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -161,3 +161,12 @@ class ToolNode(BaseNode):
|
|||
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||
for message in tool_response
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
|
@ -27,10 +27,10 @@ def test_tool_invoke():
|
|||
'tool_configurations': {},
|
||||
'tool_parameters': [
|
||||
{
|
||||
'variable': 'expression',
|
||||
'value_selector': ['1', '123', 'args1'],
|
||||
'variable_type': 'selector',
|
||||
'value': None
|
||||
'value_type': 'variable',
|
||||
'static_value': None,
|
||||
'variable_value': ['1', '123', 'args1'],
|
||||
'parameter_name': 'expression',
|
||||
},
|
||||
]
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue