mirror of https://github.com/langgenius/dify.git
Refactor workflow nodes to use generic node_data (#28782)
This commit is contained in:
parent
002d8769b0
commit
8b761319f6
|
|
@ -70,7 +70,6 @@ class AgentNode(Node[AgentNodeData]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node_type = NodeType.AGENT
|
node_type = NodeType.AGENT
|
||||||
_node_data: AgentNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
|
|
@ -82,8 +81,8 @@ class AgentNode(Node[AgentNodeData]):
|
||||||
try:
|
try:
|
||||||
strategy = get_plugin_agent_strategy(
|
strategy = get_plugin_agent_strategy(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
|
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||||
agent_strategy_name=self._node_data.agent_strategy_name,
|
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield StreamCompletedEvent(
|
yield StreamCompletedEvent(
|
||||||
|
|
@ -101,13 +100,13 @@ class AgentNode(Node[AgentNodeData]):
|
||||||
parameters = self._generate_agent_parameters(
|
parameters = self._generate_agent_parameters(
|
||||||
agent_parameters=agent_parameters,
|
agent_parameters=agent_parameters,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
node_data=self._node_data,
|
node_data=self.node_data,
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
)
|
)
|
||||||
parameters_for_log = self._generate_agent_parameters(
|
parameters_for_log = self._generate_agent_parameters(
|
||||||
agent_parameters=agent_parameters,
|
agent_parameters=agent_parameters,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
node_data=self._node_data,
|
node_data=self.node_data,
|
||||||
for_log=True,
|
for_log=True,
|
||||||
strategy=strategy,
|
strategy=strategy,
|
||||||
)
|
)
|
||||||
|
|
@ -140,7 +139,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||||
messages=message_stream,
|
messages=message_stream,
|
||||||
tool_info={
|
tool_info={
|
||||||
"icon": self.agent_strategy_icon,
|
"icon": self.agent_strategy_icon,
|
||||||
"agent_strategy": self._node_data.agent_strategy_name,
|
"agent_strategy": self.node_data.agent_strategy_name,
|
||||||
},
|
},
|
||||||
parameters_for_log=parameters_for_log,
|
parameters_for_log=parameters_for_log,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
|
|
@ -387,7 +386,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||||
current_plugin = next(
|
current_plugin = next(
|
||||||
plugin
|
plugin
|
||||||
for plugin in plugins
|
for plugin in plugins
|
||||||
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
|
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
|
||||||
)
|
)
|
||||||
icon = current_plugin.declaration.icon
|
icon = current_plugin.declaration.icon
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
|
|
|
||||||
|
|
@ -14,14 +14,12 @@ class AnswerNode(Node[AnswerNodeData]):
|
||||||
node_type = NodeType.ANSWER
|
node_type = NodeType.ANSWER
|
||||||
execution_type = NodeExecutionType.RESPONSE
|
execution_type = NodeExecutionType.RESPONSE
|
||||||
|
|
||||||
_node_data: AnswerNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
|
segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer)
|
||||||
files = self._extract_files_from_segments(segments.value)
|
files = self._extract_files_from_segments(segments.value)
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
|
@ -71,4 +69,4 @@ class AnswerNode(Node[AnswerNodeData]):
|
||||||
Returns:
|
Returns:
|
||||||
Template instance for this Answer node
|
Template instance for this Answer node
|
||||||
"""
|
"""
|
||||||
return Template.from_answer_template(self._node_data.answer)
|
return Template.from_answer_template(self.node_data.answer)
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,6 @@ from .exc import (
|
||||||
class CodeNode(Node[CodeNodeData]):
|
class CodeNode(Node[CodeNodeData]):
|
||||||
node_type = NodeType.CODE
|
node_type = NodeType.CODE
|
||||||
|
|
||||||
_node_data: CodeNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -48,12 +46,12 @@ class CodeNode(Node[CodeNodeData]):
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
# Get code language
|
# Get code language
|
||||||
code_language = self._node_data.code_language
|
code_language = self.node_data.code_language
|
||||||
code = self._node_data.code
|
code = self.node_data.code
|
||||||
|
|
||||||
# Get variables
|
# Get variables
|
||||||
variables = {}
|
variables = {}
|
||||||
for variable_selector in self._node_data.variables:
|
for variable_selector in self.node_data.variables:
|
||||||
variable_name = variable_selector.variable
|
variable_name = variable_selector.variable
|
||||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||||
if isinstance(variable, ArrayFileSegment):
|
if isinstance(variable, ArrayFileSegment):
|
||||||
|
|
@ -69,7 +67,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Transform result
|
# Transform result
|
||||||
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
|
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||||
except (CodeExecutionError, CodeNodeError) as e:
|
except (CodeExecutionError, CodeNodeError) as e:
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||||
|
|
@ -406,7 +404,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def retry(self) -> bool:
|
def retry(self) -> bool:
|
||||||
return self._node_data.retry_config.retry_enabled
|
return self.node_data.retry_config.retry_enabled
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||||
Datasource Node
|
Datasource Node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_node_data: DatasourceNodeData
|
|
||||||
node_type = NodeType.DATASOURCE
|
node_type = NodeType.DATASOURCE
|
||||||
execution_type = NodeExecutionType.ROOT
|
execution_type = NodeExecutionType.ROOT
|
||||||
|
|
||||||
|
|
@ -51,7 +50,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||||
Run the datasource node
|
Run the datasource node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node_data = self._node_data
|
node_data = self.node_data
|
||||||
variable_pool = self.graph_runtime_state.variable_pool
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||||
if not datasource_type_segement:
|
if not datasource_type_segement:
|
||||||
|
|
|
||||||
|
|
@ -43,14 +43,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||||
|
|
||||||
node_type = NodeType.DOCUMENT_EXTRACTOR
|
node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||||
|
|
||||||
_node_data: DocumentExtractorNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
variable_selector = self._node_data.variable_selector
|
variable_selector = self.node_data.variable_selector
|
||||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
||||||
|
|
||||||
if variable is None:
|
if variable is None:
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,6 @@ class EndNode(Node[EndNodeData]):
|
||||||
node_type = NodeType.END
|
node_type = NodeType.END
|
||||||
execution_type = NodeExecutionType.RESPONSE
|
execution_type = NodeExecutionType.RESPONSE
|
||||||
|
|
||||||
_node_data: EndNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
@ -22,7 +20,7 @@ class EndNode(Node[EndNodeData]):
|
||||||
This method runs after streaming is complete (if streaming was enabled).
|
This method runs after streaming is complete (if streaming was enabled).
|
||||||
It collects all output variables and returns them.
|
It collects all output variables and returns them.
|
||||||
"""
|
"""
|
||||||
output_variables = self._node_data.outputs
|
output_variables = self.node_data.outputs
|
||||||
|
|
||||||
outputs = {}
|
outputs = {}
|
||||||
for variable_selector in output_variables:
|
for variable_selector in output_variables:
|
||||||
|
|
@ -44,6 +42,6 @@ class EndNode(Node[EndNodeData]):
|
||||||
Template instance for this End node
|
Template instance for this End node
|
||||||
"""
|
"""
|
||||||
outputs_config = [
|
outputs_config = [
|
||||||
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
|
{"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs
|
||||||
]
|
]
|
||||||
return Template.from_end_outputs(outputs_config)
|
return Template.from_end_outputs(outputs_config)
|
||||||
|
|
|
||||||
|
|
@ -34,8 +34,6 @@ logger = logging.getLogger(__name__)
|
||||||
class HttpRequestNode(Node[HttpRequestNodeData]):
|
class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||||
node_type = NodeType.HTTP_REQUEST
|
node_type = NodeType.HTTP_REQUEST
|
||||||
|
|
||||||
_node_data: HttpRequestNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
|
|
@ -69,8 +67,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||||
process_data = {}
|
process_data = {}
|
||||||
try:
|
try:
|
||||||
http_executor = Executor(
|
http_executor = Executor(
|
||||||
node_data=self._node_data,
|
node_data=self.node_data,
|
||||||
timeout=self._get_request_timeout(self._node_data),
|
timeout=self._get_request_timeout(self.node_data),
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
)
|
)
|
||||||
|
|
@ -225,4 +223,4 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def retry(self) -> bool:
|
def retry(self) -> bool:
|
||||||
return self._node_data.retry_config.retry_enabled
|
return self.node_data.retry_config.retry_enabled
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,6 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||||
"handle",
|
"handle",
|
||||||
)
|
)
|
||||||
|
|
||||||
_node_data: HumanInputNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
@ -49,12 +47,12 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||||
def _is_completion_ready(self) -> bool:
|
def _is_completion_ready(self) -> bool:
|
||||||
"""Determine whether all required inputs are satisfied."""
|
"""Determine whether all required inputs are satisfied."""
|
||||||
|
|
||||||
if not self._node_data.required_variables:
|
if not self.node_data.required_variables:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
variable_pool = self.graph_runtime_state.variable_pool
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
|
||||||
for selector_str in self._node_data.required_variables:
|
for selector_str in self.node_data.required_variables:
|
||||||
parts = selector_str.split(".")
|
parts = selector_str.split(".")
|
||||||
if len(parts) != 2:
|
if len(parts) != 2:
|
||||||
return False
|
return False
|
||||||
|
|
@ -74,7 +72,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||||
if handle:
|
if handle:
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
default_values = self._node_data.default_value_dict
|
default_values = self.node_data.default_value_dict
|
||||||
for key in self._BRANCH_SELECTION_KEYS:
|
for key in self._BRANCH_SELECTION_KEYS:
|
||||||
handle = self._normalize_branch_value(default_values.get(key))
|
handle = self._normalize_branch_value(default_values.get(key))
|
||||||
if handle:
|
if handle:
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,6 @@ class IfElseNode(Node[IfElseNodeData]):
|
||||||
node_type = NodeType.IF_ELSE
|
node_type = NodeType.IF_ELSE
|
||||||
execution_type = NodeExecutionType.BRANCH
|
execution_type = NodeExecutionType.BRANCH
|
||||||
|
|
||||||
_node_data: IfElseNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
@ -37,8 +35,8 @@ class IfElseNode(Node[IfElseNodeData]):
|
||||||
condition_processor = ConditionProcessor()
|
condition_processor = ConditionProcessor()
|
||||||
try:
|
try:
|
||||||
# Check if the new cases structure is used
|
# Check if the new cases structure is used
|
||||||
if self._node_data.cases:
|
if self.node_data.cases:
|
||||||
for case in self._node_data.cases:
|
for case in self.node_data.cases:
|
||||||
input_conditions, group_result, final_result = condition_processor.process_conditions(
|
input_conditions, group_result, final_result = condition_processor.process_conditions(
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
conditions=case.conditions,
|
conditions=case.conditions,
|
||||||
|
|
@ -64,8 +62,8 @@ class IfElseNode(Node[IfElseNodeData]):
|
||||||
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
|
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
|
||||||
condition_processor=condition_processor,
|
condition_processor=condition_processor,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
conditions=self._node_data.conditions or [],
|
conditions=self.node_data.conditions or [],
|
||||||
operator=self._node_data.logical_operator or "and",
|
operator=self.node_data.logical_operator or "and",
|
||||||
)
|
)
|
||||||
|
|
||||||
selected_case_id = "true" if final_result else "false"
|
selected_case_id = "true" if final_result else "false"
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
|
|
||||||
node_type = NodeType.ITERATION
|
node_type = NodeType.ITERATION
|
||||||
execution_type = NodeExecutionType.CONTAINER
|
execution_type = NodeExecutionType.CONTAINER
|
||||||
_node_data: IterationNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
|
|
@ -136,10 +135,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
||||||
|
|
||||||
if not variable:
|
if not variable:
|
||||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
|
||||||
|
|
||||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||||
|
|
@ -174,7 +173,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
return cast(list[object], iterator_list_value)
|
return cast(list[object], iterator_list_value)
|
||||||
|
|
||||||
def _validate_start_node(self) -> None:
|
def _validate_start_node(self) -> None:
|
||||||
if not self._node_data.start_node_id:
|
if not self.node_data.start_node_id:
|
||||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||||
|
|
||||||
def _execute_iterations(
|
def _execute_iterations(
|
||||||
|
|
@ -184,7 +183,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
iter_run_map: dict[str, float],
|
iter_run_map: dict[str, float],
|
||||||
usage_accumulator: list[LLMUsage],
|
usage_accumulator: list[LLMUsage],
|
||||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||||
if self._node_data.is_parallel:
|
if self.node_data.is_parallel:
|
||||||
# Parallel mode execution
|
# Parallel mode execution
|
||||||
yield from self._execute_parallel_iterations(
|
yield from self._execute_parallel_iterations(
|
||||||
iterator_list_value=iterator_list_value,
|
iterator_list_value=iterator_list_value,
|
||||||
|
|
@ -231,7 +230,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
outputs.extend([None] * len(iterator_list_value))
|
outputs.extend([None] * len(iterator_list_value))
|
||||||
|
|
||||||
# Determine the number of parallel workers
|
# Determine the number of parallel workers
|
||||||
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
|
max_workers = min(self.node_data.parallel_nums, len(iterator_list_value))
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
# Submit all iteration tasks
|
# Submit all iteration tasks
|
||||||
|
|
@ -287,7 +286,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle errors based on error_handle_mode
|
# Handle errors based on error_handle_mode
|
||||||
match self._node_data.error_handle_mode:
|
match self.node_data.error_handle_mode:
|
||||||
case ErrorHandleMode.TERMINATED:
|
case ErrorHandleMode.TERMINATED:
|
||||||
# Cancel remaining futures and re-raise
|
# Cancel remaining futures and re-raise
|
||||||
for f in future_to_index:
|
for f in future_to_index:
|
||||||
|
|
@ -300,7 +299,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
outputs[index] = None # Will be filtered later
|
outputs[index] = None # Will be filtered later
|
||||||
|
|
||||||
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
||||||
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||||
outputs[:] = [output for output in outputs if output is not None]
|
outputs[:] = [output for output in outputs if output is not None]
|
||||||
|
|
||||||
def _execute_single_iteration_parallel(
|
def _execute_single_iteration_parallel(
|
||||||
|
|
@ -389,7 +388,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
If flatten_output is True (default), flattens the list if all elements are lists.
|
If flatten_output is True (default), flattens the list if all elements are lists.
|
||||||
"""
|
"""
|
||||||
# If flatten_output is disabled, return outputs as-is
|
# If flatten_output is disabled, return outputs as-is
|
||||||
if not self._node_data.flatten_output:
|
if not self.node_data.flatten_output:
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
if not outputs:
|
if not outputs:
|
||||||
|
|
@ -569,14 +568,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
||||||
yield event
|
yield event
|
||||||
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
||||||
result = variable_pool.get(self._node_data.output_selector)
|
result = variable_pool.get(self.node_data.output_selector)
|
||||||
if result is None:
|
if result is None:
|
||||||
outputs.append(None)
|
outputs.append(None)
|
||||||
else:
|
else:
|
||||||
outputs.append(result.to_object())
|
outputs.append(result.to_object())
|
||||||
return
|
return
|
||||||
elif isinstance(event, GraphRunFailedEvent):
|
elif isinstance(event, GraphRunFailedEvent):
|
||||||
match self._node_data.error_handle_mode:
|
match self.node_data.error_handle_mode:
|
||||||
case ErrorHandleMode.TERMINATED:
|
case ErrorHandleMode.TERMINATED:
|
||||||
raise IterationNodeError(event.error)
|
raise IterationNodeError(event.error)
|
||||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||||
|
|
@ -627,7 +626,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
|
|
||||||
# Initialize the iteration graph with the new node factory
|
# Initialize the iteration graph with the new node factory
|
||||||
iteration_graph = Graph.init(
|
iteration_graph = Graph.init(
|
||||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
|
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not iteration_graph:
|
if not iteration_graph:
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,6 @@ class IterationStartNode(Node[IterationStartNodeData]):
|
||||||
|
|
||||||
node_type = NodeType.ITERATION_START
|
node_type = NodeType.ITERATION_START
|
||||||
|
|
||||||
_node_data: IterationStartNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -35,12 +35,11 @@ default_retrieval_model = {
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||||
_node_data: KnowledgeIndexNodeData
|
|
||||||
node_type = NodeType.KNOWLEDGE_INDEX
|
node_type = NodeType.KNOWLEDGE_INDEX
|
||||||
execution_type = NodeExecutionType.RESPONSE
|
execution_type = NodeExecutionType.RESPONSE
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult: # type: ignore
|
def _run(self) -> NodeRunResult: # type: ignore
|
||||||
node_data = self._node_data
|
node_data = self.node_data
|
||||||
variable_pool = self.graph_runtime_state.variable_pool
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||||
if not dataset_id:
|
if not dataset_id:
|
||||||
|
|
|
||||||
|
|
@ -83,8 +83,6 @@ default_retrieval_model = {
|
||||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||||
|
|
||||||
_node_data: KnowledgeRetrievalNodeData
|
|
||||||
|
|
||||||
# Instance attributes specific to LLMNode.
|
# Instance attributes specific to LLMNode.
|
||||||
# Output variable for file
|
# Output variable for file
|
||||||
_file_outputs: list["File"]
|
_file_outputs: list["File"]
|
||||||
|
|
@ -122,7 +120,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
# extract variables
|
# extract variables
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
|
||||||
if not isinstance(variable, StringSegment):
|
if not isinstance(variable, StringSegment):
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
|
@ -163,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||||
# retrieve knowledge
|
# retrieve knowledge
|
||||||
usage = LLMUsage.empty_usage()
|
usage = LLMUsage.empty_usage()
|
||||||
try:
|
try:
|
||||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
|
||||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
|
@ -536,7 +534,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||||
structured_output=None,
|
structured_output=None,
|
||||||
file_saver=self._llm_file_saver,
|
file_saver=self._llm_file_saver,
|
||||||
file_outputs=self._file_outputs,
|
file_outputs=self._file_outputs,
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,6 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||||
class ListOperatorNode(Node[ListOperatorNodeData]):
|
class ListOperatorNode(Node[ListOperatorNodeData]):
|
||||||
node_type = NodeType.LIST_OPERATOR
|
node_type = NodeType.LIST_OPERATOR
|
||||||
|
|
||||||
_node_data: ListOperatorNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
@ -48,9 +46,9 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
||||||
process_data: dict[str, Sequence[object]] = {}
|
process_data: dict[str, Sequence[object]] = {}
|
||||||
outputs: dict[str, Any] = {}
|
outputs: dict[str, Any] = {}
|
||||||
|
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
|
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
|
||||||
if variable is None:
|
if variable is None:
|
||||||
error_message = f"Variable not found for selector: {self._node_data.variable}"
|
error_message = f"Variable not found for selector: {self.node_data.variable}"
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||||
)
|
)
|
||||||
|
|
@ -69,7 +67,7 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
)
|
)
|
||||||
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
||||||
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
|
error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||||
)
|
)
|
||||||
|
|
@ -83,19 +81,19 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Filter
|
# Filter
|
||||||
if self._node_data.filter_by.enabled:
|
if self.node_data.filter_by.enabled:
|
||||||
variable = self._apply_filter(variable)
|
variable = self._apply_filter(variable)
|
||||||
|
|
||||||
# Extract
|
# Extract
|
||||||
if self._node_data.extract_by.enabled:
|
if self.node_data.extract_by.enabled:
|
||||||
variable = self._extract_slice(variable)
|
variable = self._extract_slice(variable)
|
||||||
|
|
||||||
# Order
|
# Order
|
||||||
if self._node_data.order_by.enabled:
|
if self.node_data.order_by.enabled:
|
||||||
variable = self._apply_order(variable)
|
variable = self._apply_order(variable)
|
||||||
|
|
||||||
# Slice
|
# Slice
|
||||||
if self._node_data.limit.enabled:
|
if self.node_data.limit.enabled:
|
||||||
variable = self._apply_slice(variable)
|
variable = self._apply_slice(variable)
|
||||||
|
|
||||||
outputs = {
|
outputs = {
|
||||||
|
|
@ -121,7 +119,7 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
||||||
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||||
filter_func: Callable[[Any], bool]
|
filter_func: Callable[[Any], bool]
|
||||||
result: list[Any] = []
|
result: list[Any] = []
|
||||||
for condition in self._node_data.filter_by.conditions:
|
for condition in self.node_data.filter_by.conditions:
|
||||||
if isinstance(variable, ArrayStringSegment):
|
if isinstance(variable, ArrayStringSegment):
|
||||||
if not isinstance(condition.value, str):
|
if not isinstance(condition.value, str):
|
||||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||||
|
|
@ -160,22 +158,22 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
||||||
|
|
||||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||||
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
|
result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC)
|
||||||
variable = variable.model_copy(update={"value": result})
|
variable = variable.model_copy(update={"value": result})
|
||||||
else:
|
else:
|
||||||
result = _order_file(
|
result = _order_file(
|
||||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
|
||||||
)
|
)
|
||||||
variable = variable.model_copy(update={"value": result})
|
variable = variable.model_copy(update={"value": result})
|
||||||
|
|
||||||
return variable
|
return variable
|
||||||
|
|
||||||
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||||
result = variable.value[: self._node_data.limit.size]
|
result = variable.value[: self.node_data.limit.size]
|
||||||
return variable.model_copy(update={"value": result})
|
return variable.model_copy(update={"value": result})
|
||||||
|
|
||||||
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
|
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
|
||||||
if value < 1:
|
if value < 1:
|
||||||
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||||
if value > len(variable.value):
|
if value > len(variable.value):
|
||||||
|
|
|
||||||
|
|
@ -102,8 +102,6 @@ logger = logging.getLogger(__name__)
|
||||||
class LLMNode(Node[LLMNodeData]):
|
class LLMNode(Node[LLMNodeData]):
|
||||||
node_type = NodeType.LLM
|
node_type = NodeType.LLM
|
||||||
|
|
||||||
_node_data: LLMNodeData
|
|
||||||
|
|
||||||
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
|
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
|
||||||
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||||
|
|
||||||
|
|
@ -154,13 +152,13 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# init messages template
|
# init messages template
|
||||||
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
|
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
|
||||||
|
|
||||||
# fetch variables and fetch values from variable pool
|
# fetch variables and fetch values from variable pool
|
||||||
inputs = self._fetch_inputs(node_data=self._node_data)
|
inputs = self._fetch_inputs(node_data=self.node_data)
|
||||||
|
|
||||||
# fetch jinja2 inputs
|
# fetch jinja2 inputs
|
||||||
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
|
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
|
||||||
|
|
||||||
# merge inputs
|
# merge inputs
|
||||||
inputs.update(jinja_inputs)
|
inputs.update(jinja_inputs)
|
||||||
|
|
@ -169,9 +167,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
files = (
|
files = (
|
||||||
llm_utils.fetch_files(
|
llm_utils.fetch_files(
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
selector=self._node_data.vision.configs.variable_selector,
|
selector=self.node_data.vision.configs.variable_selector,
|
||||||
)
|
)
|
||||||
if self._node_data.vision.enabled
|
if self.node_data.vision.enabled
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -179,7 +177,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
node_inputs["#files#"] = [file.to_dict() for file in files]
|
node_inputs["#files#"] = [file.to_dict() for file in files]
|
||||||
|
|
||||||
# fetch context value
|
# fetch context value
|
||||||
generator = self._fetch_context(node_data=self._node_data)
|
generator = self._fetch_context(node_data=self.node_data)
|
||||||
context = None
|
context = None
|
||||||
for event in generator:
|
for event in generator:
|
||||||
context = event.context
|
context = event.context
|
||||||
|
|
@ -189,7 +187,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
|
|
||||||
# fetch model config
|
# fetch model config
|
||||||
model_instance, model_config = LLMNode._fetch_model_config(
|
model_instance, model_config = LLMNode._fetch_model_config(
|
||||||
node_data_model=self._node_data.model,
|
node_data_model=self.node_data.model,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -197,13 +195,13 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
memory = llm_utils.fetch_memory(
|
memory = llm_utils.fetch_memory(
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
app_id=self.app_id,
|
app_id=self.app_id,
|
||||||
node_data_memory=self._node_data.memory,
|
node_data_memory=self.node_data.memory,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
)
|
)
|
||||||
|
|
||||||
query: str | None = None
|
query: str | None = None
|
||||||
if self._node_data.memory:
|
if self.node_data.memory:
|
||||||
query = self._node_data.memory.query_prompt_template
|
query = self.node_data.memory.query_prompt_template
|
||||||
if not query and (
|
if not query and (
|
||||||
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||||
):
|
):
|
||||||
|
|
@ -215,29 +213,29 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
context=context,
|
context=context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
prompt_template=self._node_data.prompt_template,
|
prompt_template=self.node_data.prompt_template,
|
||||||
memory_config=self._node_data.memory,
|
memory_config=self.node_data.memory,
|
||||||
vision_enabled=self._node_data.vision.enabled,
|
vision_enabled=self.node_data.vision.enabled,
|
||||||
vision_detail=self._node_data.vision.configs.detail,
|
vision_detail=self.node_data.vision.configs.detail,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
jinja2_variables=self._node_data.prompt_config.jinja2_variables,
|
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
generator = LLMNode.invoke_llm(
|
generator = LLMNode.invoke_llm(
|
||||||
node_data_model=self._node_data.model,
|
node_data_model=self.node_data.model,
|
||||||
model_instance=model_instance,
|
model_instance=model_instance,
|
||||||
prompt_messages=prompt_messages,
|
prompt_messages=prompt_messages,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||||
structured_output=self._node_data.structured_output,
|
structured_output=self.node_data.structured_output,
|
||||||
file_saver=self._llm_file_saver,
|
file_saver=self._llm_file_saver,
|
||||||
file_outputs=self._file_outputs,
|
file_outputs=self._file_outputs,
|
||||||
node_id=self._node_id,
|
node_id=self._node_id,
|
||||||
node_type=self.node_type,
|
node_type=self.node_type,
|
||||||
reasoning_format=self._node_data.reasoning_format,
|
reasoning_format=self.node_data.reasoning_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
structured_output: LLMStructuredOutput | None = None
|
structured_output: LLMStructuredOutput | None = None
|
||||||
|
|
@ -253,12 +251,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
reasoning_content = event.reasoning_content or ""
|
reasoning_content = event.reasoning_content or ""
|
||||||
|
|
||||||
# For downstream nodes, determine clean text based on reasoning_format
|
# For downstream nodes, determine clean text based on reasoning_format
|
||||||
if self._node_data.reasoning_format == "tagged":
|
if self.node_data.reasoning_format == "tagged":
|
||||||
# Keep <think> tags for backward compatibility
|
# Keep <think> tags for backward compatibility
|
||||||
clean_text = result_text
|
clean_text = result_text
|
||||||
else:
|
else:
|
||||||
# Extract clean text from <think> tags
|
# Extract clean text from <think> tags
|
||||||
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
|
clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
|
||||||
|
|
||||||
# Process structured output if available from the event.
|
# Process structured output if available from the event.
|
||||||
structured_output = (
|
structured_output = (
|
||||||
|
|
@ -1204,7 +1202,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def retry(self) -> bool:
|
def retry(self) -> bool:
|
||||||
return self._node_data.retry_config.retry_enabled
|
return self.node_data.retry_config.retry_enabled
|
||||||
|
|
||||||
|
|
||||||
def _combine_message_content_with_role(
|
def _combine_message_content_with_role(
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,6 @@ class LoopEndNode(Node[LoopEndNodeData]):
|
||||||
|
|
||||||
node_type = NodeType.LOOP_END
|
node_type = NodeType.LOOP_END
|
||||||
|
|
||||||
_node_data: LoopEndNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
node_type = NodeType.LOOP
|
node_type = NodeType.LOOP
|
||||||
_node_data: LoopNodeData
|
|
||||||
execution_type = NodeExecutionType.CONTAINER
|
execution_type = NodeExecutionType.CONTAINER
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -56,27 +55,27 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Run the node."""
|
"""Run the node."""
|
||||||
# Get inputs
|
# Get inputs
|
||||||
loop_count = self._node_data.loop_count
|
loop_count = self.node_data.loop_count
|
||||||
break_conditions = self._node_data.break_conditions
|
break_conditions = self.node_data.break_conditions
|
||||||
logical_operator = self._node_data.logical_operator
|
logical_operator = self.node_data.logical_operator
|
||||||
|
|
||||||
inputs = {"loop_count": loop_count}
|
inputs = {"loop_count": loop_count}
|
||||||
|
|
||||||
if not self._node_data.start_node_id:
|
if not self.node_data.start_node_id:
|
||||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||||
|
|
||||||
root_node_id = self._node_data.start_node_id
|
root_node_id = self.node_data.start_node_id
|
||||||
|
|
||||||
# Initialize loop variables in the original variable pool
|
# Initialize loop variables in the original variable pool
|
||||||
loop_variable_selectors = {}
|
loop_variable_selectors = {}
|
||||||
if self._node_data.loop_variables:
|
if self.node_data.loop_variables:
|
||||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||||
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
|
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
|
||||||
if isinstance(var.value, list)
|
if isinstance(var.value, list)
|
||||||
else None,
|
else None,
|
||||||
}
|
}
|
||||||
for loop_variable in self._node_data.loop_variables:
|
for loop_variable in self.node_data.loop_variables:
|
||||||
if loop_variable.value_type not in value_processor:
|
if loop_variable.value_type not in value_processor:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
||||||
|
|
@ -164,7 +163,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
|
|
||||||
yield LoopNextEvent(
|
yield LoopNextEvent(
|
||||||
index=i + 1,
|
index=i + 1,
|
||||||
pre_loop_output=self._node_data.outputs,
|
pre_loop_output=self.node_data.outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._accumulate_usage(loop_usage)
|
self._accumulate_usage(loop_usage)
|
||||||
|
|
@ -172,7 +171,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
yield LoopSucceededEvent(
|
yield LoopSucceededEvent(
|
||||||
start_at=start_at,
|
start_at=start_at,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
outputs=self._node_data.outputs,
|
outputs=self.node_data.outputs,
|
||||||
steps=loop_count,
|
steps=loop_count,
|
||||||
metadata={
|
metadata={
|
||||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||||
|
|
@ -194,7 +193,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||||
},
|
},
|
||||||
outputs=self._node_data.outputs,
|
outputs=self.node_data.outputs,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
llm_usage=loop_usage,
|
llm_usage=loop_usage,
|
||||||
)
|
)
|
||||||
|
|
@ -252,11 +251,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
if isinstance(event, GraphRunFailedEvent):
|
if isinstance(event, GraphRunFailedEvent):
|
||||||
raise Exception(event.error)
|
raise Exception(event.error)
|
||||||
|
|
||||||
for loop_var in self._node_data.loop_variables or []:
|
for loop_var in self.node_data.loop_variables or []:
|
||||||
key, sel = loop_var.label, [self._node_id, loop_var.label]
|
key, sel = loop_var.label, [self._node_id, loop_var.label]
|
||||||
segment = self.graph_runtime_state.variable_pool.get(sel)
|
segment = self.graph_runtime_state.variable_pool.get(sel)
|
||||||
self._node_data.outputs[key] = segment.value if segment else None
|
self.node_data.outputs[key] = segment.value if segment else None
|
||||||
self._node_data.outputs["loop_round"] = current_index + 1
|
self.node_data.outputs["loop_round"] = current_index + 1
|
||||||
|
|
||||||
return reach_break_node
|
return reach_break_node
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,6 @@ class LoopStartNode(Node[LoopStartNodeData]):
|
||||||
|
|
||||||
node_type = NodeType.LOOP_START
|
node_type = NodeType.LOOP_START
|
||||||
|
|
||||||
_node_data: LoopStartNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -90,8 +90,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||||
|
|
||||||
node_type = NodeType.PARAMETER_EXTRACTOR
|
node_type = NodeType.PARAMETER_EXTRACTOR
|
||||||
|
|
||||||
_node_data: ParameterExtractorNodeData
|
|
||||||
|
|
||||||
_model_instance: ModelInstance | None = None
|
_model_instance: ModelInstance | None = None
|
||||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||||
|
|
||||||
|
|
@ -116,7 +114,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||||
"""
|
"""
|
||||||
Run the node.
|
Run the node.
|
||||||
"""
|
"""
|
||||||
node_data = self._node_data
|
node_data = self.node_data
|
||||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
|
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
|
||||||
query = variable.text if variable else ""
|
query = variable.text if variable else ""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,8 +47,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
node_type = NodeType.QUESTION_CLASSIFIER
|
node_type = NodeType.QUESTION_CLASSIFIER
|
||||||
execution_type = NodeExecutionType.BRANCH
|
execution_type = NodeExecutionType.BRANCH
|
||||||
|
|
||||||
_node_data: QuestionClassifierNodeData
|
|
||||||
|
|
||||||
_file_outputs: list["File"]
|
_file_outputs: list["File"]
|
||||||
_llm_file_saver: LLMFileSaver
|
_llm_file_saver: LLMFileSaver
|
||||||
|
|
||||||
|
|
@ -82,7 +80,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
node_data = self._node_data
|
node_data = self.node_data
|
||||||
variable_pool = self.graph_runtime_state.variable_pool
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
|
||||||
# extract variables
|
# extract variables
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,6 @@ class StartNode(Node[StartNodeData]):
|
||||||
node_type = NodeType.START
|
node_type = NodeType.START
|
||||||
execution_type = NodeExecutionType.ROOT
|
execution_type = NodeExecutionType.ROOT
|
||||||
|
|
||||||
_node_data: StartNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,6 @@ MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||||
|
|
||||||
_node_data: TemplateTransformNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -35,14 +33,14 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
# Get variables
|
# Get variables
|
||||||
variables: dict[str, Any] = {}
|
variables: dict[str, Any] = {}
|
||||||
for variable_selector in self._node_data.variables:
|
for variable_selector in self.node_data.variables:
|
||||||
variable_name = variable_selector.variable
|
variable_name = variable_selector.variable
|
||||||
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||||
variables[variable_name] = value.to_object() if value else None
|
variables[variable_name] = value.to_object() if value else None
|
||||||
# Run code
|
# Run code
|
||||||
try:
|
try:
|
||||||
result = CodeExecutor.execute_workflow_code_template(
|
result = CodeExecutor.execute_workflow_code_template(
|
||||||
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
|
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
|
||||||
)
|
)
|
||||||
except CodeExecutionError as e:
|
except CodeExecutionError as e:
|
||||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||||
|
|
|
||||||
|
|
@ -47,8 +47,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
|
|
||||||
node_type = NodeType.TOOL
|
node_type = NodeType.TOOL
|
||||||
|
|
||||||
_node_data: ToolNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
@ -59,13 +57,11 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
"""
|
"""
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||||
|
|
||||||
node_data = self._node_data
|
|
||||||
|
|
||||||
# fetch tool icon
|
# fetch tool icon
|
||||||
tool_info = {
|
tool_info = {
|
||||||
"provider_type": node_data.provider_type.value,
|
"provider_type": self.node_data.provider_type.value,
|
||||||
"provider_id": node_data.provider_id,
|
"provider_id": self.node_data.provider_id,
|
||||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
# get tool runtime
|
# get tool runtime
|
||||||
|
|
@ -77,10 +73,10 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
# But for backward compatibility with historical data
|
# But for backward compatibility with historical data
|
||||||
# this version field judgment is still preserved here.
|
# this version field judgment is still preserved here.
|
||||||
variable_pool: VariablePool | None = None
|
variable_pool: VariablePool | None = None
|
||||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
|
||||||
variable_pool = self.graph_runtime_state.variable_pool
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||||
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
|
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
|
||||||
)
|
)
|
||||||
except ToolNodeError as e:
|
except ToolNodeError as e:
|
||||||
yield StreamCompletedEvent(
|
yield StreamCompletedEvent(
|
||||||
|
|
@ -99,12 +95,12 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
parameters = self._generate_parameters(
|
parameters = self._generate_parameters(
|
||||||
tool_parameters=tool_parameters,
|
tool_parameters=tool_parameters,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
node_data=self._node_data,
|
node_data=self.node_data,
|
||||||
)
|
)
|
||||||
parameters_for_log = self._generate_parameters(
|
parameters_for_log = self._generate_parameters(
|
||||||
tool_parameters=tool_parameters,
|
tool_parameters=tool_parameters,
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
node_data=self._node_data,
|
node_data=self.node_data,
|
||||||
for_log=True,
|
for_log=True,
|
||||||
)
|
)
|
||||||
# get conversation id
|
# get conversation id
|
||||||
|
|
@ -149,7 +145,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||||
error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
|
error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
|
||||||
error_type=type(e).__name__,
|
error_type=type(e).__name__,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -159,7 +155,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||||
error=e.to_user_friendly_error(plugin_name=node_data.provider_name),
|
error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
|
||||||
error_type=type(e).__name__,
|
error_type=type(e).__name__,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -495,4 +491,4 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def retry(self) -> bool:
|
def retry(self) -> bool:
|
||||||
return self._node_data.retry_config.retry_enabled
|
return self.node_data.retry_config.retry_enabled
|
||||||
|
|
|
||||||
|
|
@ -43,9 +43,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||||
# Get trigger data passed when workflow was triggered
|
# Get trigger data passed when workflow was triggered
|
||||||
metadata = {
|
metadata = {
|
||||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||||
"provider_id": self._node_data.provider_id,
|
"provider_id": self.node_data.provider_id,
|
||||||
"event_name": self._node_data.event_name,
|
"event_name": self.node_data.event_name,
|
||||||
"plugin_unique_identifier": self._node_data.plugin_unique_identifier,
|
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||||
webhook_headers = webhook_data.get("headers", {})
|
webhook_headers = webhook_data.get("headers", {})
|
||||||
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
|
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
|
||||||
|
|
||||||
for header in self._node_data.headers:
|
for header in self.node_data.headers:
|
||||||
header_name = header.name
|
header_name = header.name
|
||||||
value = _get_normalized(webhook_headers, header_name)
|
value = _get_normalized(webhook_headers, header_name)
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|
@ -93,20 +93,20 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||||
outputs[sanitized_name] = value
|
outputs[sanitized_name] = value
|
||||||
|
|
||||||
# Extract configured query parameters
|
# Extract configured query parameters
|
||||||
for param in self._node_data.params:
|
for param in self.node_data.params:
|
||||||
param_name = param.name
|
param_name = param.name
|
||||||
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
|
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
|
||||||
|
|
||||||
# Extract configured body parameters
|
# Extract configured body parameters
|
||||||
for body_param in self._node_data.body:
|
for body_param in self.node_data.body:
|
||||||
param_name = body_param.name
|
param_name = body_param.name
|
||||||
param_type = body_param.type
|
param_type = body_param.type
|
||||||
|
|
||||||
if self._node_data.content_type == ContentType.TEXT:
|
if self.node_data.content_type == ContentType.TEXT:
|
||||||
# For text/plain, the entire body is a single string parameter
|
# For text/plain, the entire body is a single string parameter
|
||||||
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
|
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
|
||||||
continue
|
continue
|
||||||
elif self._node_data.content_type == ContentType.BINARY:
|
elif self.node_data.content_type == ContentType.BINARY:
|
||||||
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
|
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,6 @@ from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorN
|
||||||
class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
|
class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
|
||||||
node_type = NodeType.VARIABLE_AGGREGATOR
|
node_type = NodeType.VARIABLE_AGGREGATOR
|
||||||
|
|
||||||
_node_data: VariableAggregatorNodeData
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
@ -21,8 +19,8 @@ class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
|
||||||
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
|
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
|
||||||
inputs = {}
|
inputs = {}
|
||||||
|
|
||||||
if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
|
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
|
||||||
for selector in self._node_data.variables:
|
for selector in self.node_data.variables:
|
||||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||||
if variable is not None:
|
if variable is not None:
|
||||||
outputs = {"output": variable}
|
outputs = {"output": variable}
|
||||||
|
|
@ -30,7 +28,7 @@ class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
|
||||||
inputs = {".".join(selector[1:]): variable.to_object()}
|
inputs = {".".join(selector[1:]): variable.to_object()}
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
for group in self._node_data.advanced_settings.groups:
|
for group in self.node_data.advanced_settings.groups:
|
||||||
for selector in group.variables:
|
for selector in group.variables:
|
||||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||||
node_type = NodeType.VARIABLE_ASSIGNER
|
node_type = NodeType.VARIABLE_ASSIGNER
|
||||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||||
|
|
||||||
_node_data: VariableAssignerData
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
id: str,
|
id: str,
|
||||||
|
|
@ -71,21 +69,21 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||||
return mapping
|
return mapping
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
assigned_variable_selector = self._node_data.assigned_variable_selector
|
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||||
if not isinstance(original_variable, Variable):
|
if not isinstance(original_variable, Variable):
|
||||||
raise VariableOperatorNodeError("assigned variable not found")
|
raise VariableOperatorNodeError("assigned variable not found")
|
||||||
|
|
||||||
match self._node_data.write_mode:
|
match self.node_data.write_mode:
|
||||||
case WriteMode.OVER_WRITE:
|
case WriteMode.OVER_WRITE:
|
||||||
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
|
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||||
if not income_value:
|
if not income_value:
|
||||||
raise VariableOperatorNodeError("input value not found")
|
raise VariableOperatorNodeError("input value not found")
|
||||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||||
|
|
||||||
case WriteMode.APPEND:
|
case WriteMode.APPEND:
|
||||||
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
|
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||||
if not income_value:
|
if not income_value:
|
||||||
raise VariableOperatorNodeError("input value not found")
|
raise VariableOperatorNodeError("input value not found")
|
||||||
updated_value = original_variable.value + [income_value.value]
|
updated_value = original_variable.value + [income_value.value]
|
||||||
|
|
|
||||||
|
|
@ -53,8 +53,6 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
||||||
class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||||
node_type = NodeType.VARIABLE_ASSIGNER
|
node_type = NodeType.VARIABLE_ASSIGNER
|
||||||
|
|
||||||
_node_data: VariableAssignerNodeData
|
|
||||||
|
|
||||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if this Variable Assigner node blocks the output of specific variables.
|
Check if this Variable Assigner node blocks the output of specific variables.
|
||||||
|
|
@ -62,7 +60,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||||
Returns True if this node updates any of the requested conversation variables.
|
Returns True if this node updates any of the requested conversation variables.
|
||||||
"""
|
"""
|
||||||
# Check each item in this Variable Assigner node
|
# Check each item in this Variable Assigner node
|
||||||
for item in self._node_data.items:
|
for item in self.node_data.items:
|
||||||
# Convert the item's variable_selector to tuple for comparison
|
# Convert the item's variable_selector to tuple for comparison
|
||||||
item_selector_tuple = tuple(item.variable_selector)
|
item_selector_tuple = tuple(item.variable_selector)
|
||||||
|
|
||||||
|
|
@ -97,13 +95,13 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||||
return var_mapping
|
return var_mapping
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
inputs = self._node_data.model_dump()
|
inputs = self.node_data.model_dump()
|
||||||
process_data: dict[str, Any] = {}
|
process_data: dict[str, Any] = {}
|
||||||
# NOTE: This node has no outputs
|
# NOTE: This node has no outputs
|
||||||
updated_variable_selectors: list[Sequence[str]] = []
|
updated_variable_selectors: list[Sequence[str]] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for item in self._node_data.items:
|
for item in self.node_data.items:
|
||||||
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
|
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
|
||||||
|
|
||||||
# ==================== Validation Part
|
# ==================== Validation Part
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue