mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 18:27:15 +08:00
refactor
This commit is contained in:
parent
a647698c32
commit
fb364d44d1
@ -1,7 +1,5 @@
|
|||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
|
||||||
ToolParameterValue = Union[str, int, float, bool]
|
ToolParameterValue = Union[str, int, float, bool]
|
||||||
@ -16,23 +14,23 @@ class ToolEntity(BaseModel):
|
|||||||
|
|
||||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||||
class ToolInput(BaseModel):
|
class ToolInput(BaseModel):
|
||||||
variable: str
|
value_type: Literal['variable', 'static']
|
||||||
variable_type: Literal['selector', 'static']
|
static_value: Optional[Union[int, float, str]]
|
||||||
value_selector: Optional[list[str]]
|
template_value: Optional[str]
|
||||||
value: Optional[str]
|
parameter_name: str
|
||||||
|
|
||||||
@validator('value')
|
@validator('value_type', pre=True, always=True)
|
||||||
def check_value(cls, value, values, **kwargs):
|
def check_value_type(cls, value, values):
|
||||||
if values['variable_type'] == 'static' and value is None:
|
if value == 'variable':
|
||||||
raise ValueError('value is required for static variable')
|
# check if template_value is None
|
||||||
|
if values.get('template_value') is not None:
|
||||||
|
raise ValueError('template_value must be None for value_type variable')
|
||||||
|
elif value == 'static':
|
||||||
|
# check if static_value is None
|
||||||
|
if values.get('static_value') is None:
|
||||||
|
raise ValueError('static_value must be provided for value_type static')
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@validator('value_selector')
|
|
||||||
def check_value_selector(cls, value_selector, values, **kwargs):
|
|
||||||
if values['variable_type'] == 'selector' and value_selector is None:
|
|
||||||
raise ValueError('value_selector is required for selector variable')
|
|
||||||
return value_selector
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Tool Node Schema
|
Tool Node Schema
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
|||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
@ -71,12 +72,26 @@ class ToolNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
Generate parameters
|
Generate parameters
|
||||||
"""
|
"""
|
||||||
return {
|
result = {}
|
||||||
k.variable:
|
for parameter in node_data.tool_parameters:
|
||||||
k.value if k.variable_type == 'static' else
|
if parameter.value_type == 'static':
|
||||||
variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else ''
|
result[parameter.parameter_name] = parameter.static_value
|
||||||
for k in node_data.tool_parameters
|
else:
|
||||||
}
|
parser = VariableTemplateParser(parameter.template_value)
|
||||||
|
variable_selectors = parser.extract_variable_selectors()
|
||||||
|
values = {
|
||||||
|
selector.variable: variable_pool.get_variable_value(selector)
|
||||||
|
for selector in variable_selectors
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(values) == 1:
|
||||||
|
# if only one value, use the value directly to avoid type transformation
|
||||||
|
result[parameter.parameter_name] = list(values.values())[0]
|
||||||
|
else:
|
||||||
|
# if multiple values, use the parser to format the values into a string
|
||||||
|
result[parameter.parameter_name] = parser.format(values)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||||
"""
|
"""
|
||||||
@ -146,14 +161,3 @@ class ToolNode(BaseNode):
|
|||||||
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||||
for message in tool_response
|
for message in tool_response
|
||||||
])
|
])
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
|
|
||||||
"""
|
|
||||||
Extract variable selector to variable mapping
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
k.variable: k.value_selector
|
|
||||||
for k in node_data.tool_parameters
|
|
||||||
if k.variable_type == 'selector'
|
|
||||||
}
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user