mirror of https://github.com/langgenius/dify.git
refactor
This commit is contained in:
parent
a647698c32
commit
fb364d44d1
|
|
@ -1,7 +1,5 @@
|
|||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
ToolParameterValue = Union[str, int, float, bool]
|
||||
|
|
@ -16,23 +14,23 @@ class ToolEntity(BaseModel):
|
|||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
class ToolInput(BaseModel):
|
||||
variable: str
|
||||
variable_type: Literal['selector', 'static']
|
||||
value_selector: Optional[list[str]]
|
||||
value: Optional[str]
|
||||
value_type: Literal['variable', 'static']
|
||||
static_value: Optional[Union[int, float, str]]
|
||||
template_value: Optional[str]
|
||||
parameter_name: str
|
||||
|
||||
@validator('value')
|
||||
def check_value(cls, value, values, **kwargs):
|
||||
if values['variable_type'] == 'static' and value is None:
|
||||
raise ValueError('value is required for static variable')
|
||||
@validator('value_type', pre=True, always=True)
|
||||
def check_value_type(cls, value, values):
|
||||
if value == 'variable':
|
||||
# check if template_value is None
|
||||
if values.get('template_value') is not None:
|
||||
raise ValueError('template_value must be None for value_type variable')
|
||||
elif value == 'static':
|
||||
# check if static_value is None
|
||||
if values.get('static_value') is None:
|
||||
raise ValueError('static_value must be provided for value_type static')
|
||||
return value
|
||||
|
||||
@validator('value_selector')
|
||||
def check_value_selector(cls, value_selector, values, **kwargs):
|
||||
if values['variable_type'] == 'selector' and value_selector is None:
|
||||
raise ValueError('value_selector is required for selector variable')
|
||||
return value_selector
|
||||
|
||||
|
||||
"""
|
||||
Tool Node Schema
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
|
|
@ -71,12 +72,26 @@ class ToolNode(BaseNode):
|
|||
"""
|
||||
Generate parameters
|
||||
"""
|
||||
return {
|
||||
k.variable:
|
||||
k.value if k.variable_type == 'static' else
|
||||
variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else ''
|
||||
for k in node_data.tool_parameters
|
||||
}
|
||||
result = {}
|
||||
for parameter in node_data.tool_parameters:
|
||||
if parameter.value_type == 'static':
|
||||
result[parameter.parameter_name] = parameter.static_value
|
||||
else:
|
||||
parser = VariableTemplateParser(parameter.template_value)
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
values = {
|
||||
selector.variable: variable_pool.get_variable_value(selector)
|
||||
for selector in variable_selectors
|
||||
}
|
||||
|
||||
if len(values) == 1:
|
||||
# if only one value, use the value directly to avoid type transformation
|
||||
result[parameter.parameter_name] = list(values.values())[0]
|
||||
else:
|
||||
# if multiple values, use the parser to format the values into a string
|
||||
result[parameter.parameter_name] = parser.format(values)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||
"""
|
||||
|
|
@ -146,14 +161,3 @@ class ToolNode(BaseNode):
|
|||
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||
for message in tool_response
|
||||
])
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
"""
|
||||
return {
|
||||
k.variable: k.value_selector
|
||||
for k in node_data.tool_parameters
|
||||
if k.variable_type == 'selector'
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue