Refactor workflow nodes to use generic node_data (#28782)

This commit is contained in:
-LAN- 2025-11-27 20:46:56 +08:00 committed by GitHub
parent 002d8769b0
commit 8b761319f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 121 additions and 170 deletions

View File

@ -70,7 +70,6 @@ class AgentNode(Node[AgentNodeData]):
""" """
node_type = NodeType.AGENT node_type = NodeType.AGENT
_node_data: AgentNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
@ -82,8 +81,8 @@ class AgentNode(Node[AgentNodeData]):
try: try:
strategy = get_plugin_agent_strategy( strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name, agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self._node_data.agent_strategy_name, agent_strategy_name=self.node_data.agent_strategy_name,
) )
except Exception as e: except Exception as e:
yield StreamCompletedEvent( yield StreamCompletedEvent(
@ -101,13 +100,13 @@ class AgentNode(Node[AgentNodeData]):
parameters = self._generate_agent_parameters( parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters, agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data, node_data=self.node_data,
strategy=strategy, strategy=strategy,
) )
parameters_for_log = self._generate_agent_parameters( parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters, agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data, node_data=self.node_data,
for_log=True, for_log=True,
strategy=strategy, strategy=strategy,
) )
@ -140,7 +139,7 @@ class AgentNode(Node[AgentNodeData]):
messages=message_stream, messages=message_stream,
tool_info={ tool_info={
"icon": self.agent_strategy_icon, "icon": self.agent_strategy_icon,
"agent_strategy": self._node_data.agent_strategy_name, "agent_strategy": self.node_data.agent_strategy_name,
}, },
parameters_for_log=parameters_for_log, parameters_for_log=parameters_for_log,
user_id=self.user_id, user_id=self.user_id,
@ -387,7 +386,7 @@ class AgentNode(Node[AgentNodeData]):
current_plugin = next( current_plugin = next(
plugin plugin
for plugin in plugins for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
) )
icon = current_plugin.declaration.icon icon = current_plugin.declaration.icon
except StopIteration: except StopIteration:

View File

@ -14,14 +14,12 @@ class AnswerNode(Node[AnswerNodeData]):
node_type = NodeType.ANSWER node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE execution_type = NodeExecutionType.RESPONSE
_node_data: AnswerNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer) segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer)
files = self._extract_files_from_segments(segments.value) files = self._extract_files_from_segments(segments.value)
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -71,4 +69,4 @@ class AnswerNode(Node[AnswerNodeData]):
Returns: Returns:
Template instance for this Answer node Template instance for this Answer node
""" """
return Template.from_answer_template(self._node_data.answer) return Template.from_answer_template(self.node_data.answer)

View File

@ -24,8 +24,6 @@ from .exc import (
class CodeNode(Node[CodeNodeData]): class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE node_type = NodeType.CODE
_node_data: CodeNodeData
@classmethod @classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
""" """
@ -48,12 +46,12 @@ class CodeNode(Node[CodeNodeData]):
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get code language # Get code language
code_language = self._node_data.code_language code_language = self.node_data.code_language
code = self._node_data.code code = self.node_data.code
# Get variables # Get variables
variables = {} variables = {}
for variable_selector in self._node_data.variables: for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment): if isinstance(variable, ArrayFileSegment):
@ -69,7 +67,7 @@ class CodeNode(Node[CodeNodeData]):
) )
# Transform result # Transform result
result = self._transform_result(result=result, output_schema=self._node_data.outputs) result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e: except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -406,7 +404,7 @@ class CodeNode(Node[CodeNodeData]):
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled return self.node_data.retry_config.retry_enabled
@staticmethod @staticmethod
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None: def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:

View File

@ -42,7 +42,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
Datasource Node Datasource Node
""" """
_node_data: DatasourceNodeData
node_type = NodeType.DATASOURCE node_type = NodeType.DATASOURCE
execution_type = NodeExecutionType.ROOT execution_type = NodeExecutionType.ROOT
@ -51,7 +50,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
Run the datasource node Run the datasource node
""" """
node_data = self._node_data node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segement: if not datasource_type_segement:

View File

@ -43,14 +43,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
node_type = NodeType.DOCUMENT_EXTRACTOR node_type = NodeType.DOCUMENT_EXTRACTOR
_node_data: DocumentExtractorNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
def _run(self): def _run(self):
variable_selector = self._node_data.variable_selector variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None: if variable is None:

View File

@ -9,8 +9,6 @@ class EndNode(Node[EndNodeData]):
node_type = NodeType.END node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE execution_type = NodeExecutionType.RESPONSE
_node_data: EndNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -22,7 +20,7 @@ class EndNode(Node[EndNodeData]):
This method runs after streaming is complete (if streaming was enabled). This method runs after streaming is complete (if streaming was enabled).
It collects all output variables and returns them. It collects all output variables and returns them.
""" """
output_variables = self._node_data.outputs output_variables = self.node_data.outputs
outputs = {} outputs = {}
for variable_selector in output_variables: for variable_selector in output_variables:
@ -44,6 +42,6 @@ class EndNode(Node[EndNodeData]):
Template instance for this End node Template instance for this End node
""" """
outputs_config = [ outputs_config = [
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs {"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs
] ]
return Template.from_end_outputs(outputs_config) return Template.from_end_outputs(outputs_config)

View File

@ -34,8 +34,6 @@ logger = logging.getLogger(__name__)
class HttpRequestNode(Node[HttpRequestNodeData]): class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST node_type = NodeType.HTTP_REQUEST
_node_data: HttpRequestNodeData
@classmethod @classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return { return {
@ -69,8 +67,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
process_data = {} process_data = {}
try: try:
http_executor = Executor( http_executor = Executor(
node_data=self._node_data, node_data=self.node_data,
timeout=self._get_request_timeout(self._node_data), timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0, max_retries=0,
) )
@ -225,4 +223,4 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled return self.node_data.retry_config.retry_enabled

View File

@ -25,8 +25,6 @@ class HumanInputNode(Node[HumanInputNodeData]):
"handle", "handle",
) )
_node_data: HumanInputNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -49,12 +47,12 @@ class HumanInputNode(Node[HumanInputNodeData]):
def _is_completion_ready(self) -> bool: def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied.""" """Determine whether all required inputs are satisfied."""
if not self._node_data.required_variables: if not self.node_data.required_variables:
return False return False
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
for selector_str in self._node_data.required_variables: for selector_str in self.node_data.required_variables:
parts = selector_str.split(".") parts = selector_str.split(".")
if len(parts) != 2: if len(parts) != 2:
return False return False
@ -74,7 +72,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
if handle: if handle:
return handle return handle
default_values = self._node_data.default_value_dict default_values = self.node_data.default_value_dict
for key in self._BRANCH_SELECTION_KEYS: for key in self._BRANCH_SELECTION_KEYS:
handle = self._normalize_branch_value(default_values.get(key)) handle = self._normalize_branch_value(default_values.get(key))
if handle: if handle:

View File

@ -16,8 +16,6 @@ class IfElseNode(Node[IfElseNodeData]):
node_type = NodeType.IF_ELSE node_type = NodeType.IF_ELSE
execution_type = NodeExecutionType.BRANCH execution_type = NodeExecutionType.BRANCH
_node_data: IfElseNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -37,8 +35,8 @@ class IfElseNode(Node[IfElseNodeData]):
condition_processor = ConditionProcessor() condition_processor = ConditionProcessor()
try: try:
# Check if the new cases structure is used # Check if the new cases structure is used
if self._node_data.cases: if self.node_data.cases:
for case in self._node_data.cases: for case in self.node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions( input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions, conditions=case.conditions,
@ -64,8 +62,8 @@ class IfElseNode(Node[IfElseNodeData]):
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
condition_processor=condition_processor, condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
conditions=self._node_data.conditions or [], conditions=self.node_data.conditions or [],
operator=self._node_data.logical_operator or "and", operator=self.node_data.logical_operator or "and",
) )
selected_case_id = "true" if final_result else "false" selected_case_id = "true" if final_result else "false"

View File

@ -65,7 +65,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
node_type = NodeType.ITERATION node_type = NodeType.ITERATION
execution_type = NodeExecutionType.CONTAINER execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
@classmethod @classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -136,10 +135,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
) )
def _get_iterator_variable(self) -> ArraySegment | NoneSegment: def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not variable: if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@ -174,7 +173,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return cast(list[object], iterator_list_value) return cast(list[object], iterator_list_value)
def _validate_start_node(self) -> None: def _validate_start_node(self) -> None:
if not self._node_data.start_node_id: if not self.node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
def _execute_iterations( def _execute_iterations(
@ -184,7 +183,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
iter_run_map: dict[str, float], iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage], usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel: if self.node_data.is_parallel:
# Parallel mode execution # Parallel mode execution
yield from self._execute_parallel_iterations( yield from self._execute_parallel_iterations(
iterator_list_value=iterator_list_value, iterator_list_value=iterator_list_value,
@ -231,7 +230,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
outputs.extend([None] * len(iterator_list_value)) outputs.extend([None] * len(iterator_list_value))
# Determine the number of parallel workers # Determine the number of parallel workers
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value)) max_workers = min(self.node_data.parallel_nums, len(iterator_list_value))
with ThreadPoolExecutor(max_workers=max_workers) as executor: with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks # Submit all iteration tasks
@ -287,7 +286,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
except Exception as e: except Exception as e:
# Handle errors based on error_handle_mode # Handle errors based on error_handle_mode
match self._node_data.error_handle_mode: match self.node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED: case ErrorHandleMode.TERMINATED:
# Cancel remaining futures and re-raise # Cancel remaining futures and re-raise
for f in future_to_index: for f in future_to_index:
@ -300,7 +299,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
outputs[index] = None # Will be filtered later outputs[index] = None # Will be filtered later
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode # Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[:] = [output for output in outputs if output is not None] outputs[:] = [output for output in outputs if output is not None]
def _execute_single_iteration_parallel( def _execute_single_iteration_parallel(
@ -389,7 +388,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
If flatten_output is True (default), flattens the list if all elements are lists. If flatten_output is True (default), flattens the list if all elements are lists.
""" """
# If flatten_output is disabled, return outputs as-is # If flatten_output is disabled, return outputs as-is
if not self._node_data.flatten_output: if not self.node_data.flatten_output:
return outputs return outputs
if not outputs: if not outputs:
@ -569,14 +568,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index) self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event yield event
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)): elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector) result = variable_pool.get(self.node_data.output_selector)
if result is None: if result is None:
outputs.append(None) outputs.append(None)
else: else:
outputs.append(result.to_object()) outputs.append(result.to_object())
return return
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
match self._node_data.error_handle_mode: match self.node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED: case ErrorHandleMode.TERMINATED:
raise IterationNodeError(event.error) raise IterationNodeError(event.error)
case ErrorHandleMode.CONTINUE_ON_ERROR: case ErrorHandleMode.CONTINUE_ON_ERROR:
@ -627,7 +626,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# Initialize the iteration graph with the new node factory # Initialize the iteration graph with the new node factory
iteration_graph = Graph.init( iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
) )
if not iteration_graph: if not iteration_graph:

View File

@ -11,8 +11,6 @@ class IterationStartNode(Node[IterationStartNodeData]):
node_type = NodeType.ITERATION_START node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

View File

@ -35,12 +35,11 @@ default_retrieval_model = {
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
_node_data: KnowledgeIndexNodeData
node_type = NodeType.KNOWLEDGE_INDEX node_type = NodeType.KNOWLEDGE_INDEX
execution_type = NodeExecutionType.RESPONSE execution_type = NodeExecutionType.RESPONSE
def _run(self) -> NodeRunResult: # type: ignore def _run(self) -> NodeRunResult: # type: ignore
node_data = self._node_data node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id: if not dataset_id:

View File

@ -83,8 +83,6 @@ default_retrieval_model = {
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = NodeType.KNOWLEDGE_RETRIEVAL node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
# Instance attributes specific to LLMNode. # Instance attributes specific to LLMNode.
# Output variable for file # Output variable for file
_file_outputs: list["File"] _file_outputs: list["File"]
@ -122,7 +120,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# extract variables # extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
if not isinstance(variable, StringSegment): if not isinstance(variable, StringSegment):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
@ -163,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
# retrieve knowledge # retrieve knowledge
usage = LLMUsage.empty_usage() usage = LLMUsage.empty_usage()
try: try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query) results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)} outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -536,7 +534,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id, user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled, structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None, structured_output=None,
file_saver=self._llm_file_saver, file_saver=self._llm_file_saver,
file_outputs=self._file_outputs, file_outputs=self._file_outputs,

View File

@ -37,8 +37,6 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
class ListOperatorNode(Node[ListOperatorNodeData]): class ListOperatorNode(Node[ListOperatorNodeData]):
node_type = NodeType.LIST_OPERATOR node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -48,9 +46,9 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
process_data: dict[str, Sequence[object]] = {} process_data: dict[str, Sequence[object]] = {}
outputs: dict[str, Any] = {} outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
if variable is None: if variable is None:
error_message = f"Variable not found for selector: {self._node_data.variable}" error_message = f"Variable not found for selector: {self.node_data.variable}"
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
) )
@ -69,7 +67,7 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
outputs=outputs, outputs=outputs,
) )
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE): if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}" error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}"
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
) )
@ -83,19 +81,19 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
try: try:
# Filter # Filter
if self._node_data.filter_by.enabled: if self.node_data.filter_by.enabled:
variable = self._apply_filter(variable) variable = self._apply_filter(variable)
# Extract # Extract
if self._node_data.extract_by.enabled: if self.node_data.extract_by.enabled:
variable = self._extract_slice(variable) variable = self._extract_slice(variable)
# Order # Order
if self._node_data.order_by.enabled: if self.node_data.order_by.enabled:
variable = self._apply_order(variable) variable = self._apply_order(variable)
# Slice # Slice
if self._node_data.limit.enabled: if self.node_data.limit.enabled:
variable = self._apply_slice(variable) variable = self._apply_slice(variable)
outputs = { outputs = {
@ -121,7 +119,7 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
filter_func: Callable[[Any], bool] filter_func: Callable[[Any], bool]
result: list[Any] = [] result: list[Any] = []
for condition in self._node_data.filter_by.conditions: for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment): if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str): if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@ -160,22 +158,22 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)): if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC) result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC)
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
else: else:
result = _order_file( result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
) )
variable = variable.model_copy(update={"value": result}) variable = variable.model_copy(update={"value": result})
return variable return variable
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
result = variable.value[: self._node_data.limit.size] result = variable.value[: self.node_data.limit.size]
return variable.model_copy(update={"value": result}) return variable.model_copy(update={"value": result})
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text) value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
if value < 1: if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}") raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
if value > len(variable.value): if value > len(variable.value):

View File

@ -102,8 +102,6 @@ logger = logging.getLogger(__name__)
class LLMNode(Node[LLMNodeData]): class LLMNode(Node[LLMNodeData]):
node_type = NodeType.LLM node_type = NodeType.LLM
_node_data: LLMNodeData
# Compiled regex for extracting <think> blocks (with compatibility for attributes) # Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL) _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
@ -154,13 +152,13 @@ class LLMNode(Node[LLMNodeData]):
try: try:
# init messages template # init messages template
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template) self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# fetch variables and fetch values from variable pool # fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self._node_data) inputs = self._fetch_inputs(node_data=self.node_data)
# fetch jinja2 inputs # fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data) jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
# merge inputs # merge inputs
inputs.update(jinja_inputs) inputs.update(jinja_inputs)
@ -169,9 +167,9 @@ class LLMNode(Node[LLMNodeData]):
files = ( files = (
llm_utils.fetch_files( llm_utils.fetch_files(
variable_pool=variable_pool, variable_pool=variable_pool,
selector=self._node_data.vision.configs.variable_selector, selector=self.node_data.vision.configs.variable_selector,
) )
if self._node_data.vision.enabled if self.node_data.vision.enabled
else [] else []
) )
@ -179,7 +177,7 @@ class LLMNode(Node[LLMNodeData]):
node_inputs["#files#"] = [file.to_dict() for file in files] node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value # fetch context value
generator = self._fetch_context(node_data=self._node_data) generator = self._fetch_context(node_data=self.node_data)
context = None context = None
for event in generator: for event in generator:
context = event.context context = event.context
@ -189,7 +187,7 @@ class LLMNode(Node[LLMNodeData]):
# fetch model config # fetch model config
model_instance, model_config = LLMNode._fetch_model_config( model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self._node_data.model, node_data_model=self.node_data.model,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
) )
@ -197,13 +195,13 @@ class LLMNode(Node[LLMNodeData]):
memory = llm_utils.fetch_memory( memory = llm_utils.fetch_memory(
variable_pool=variable_pool, variable_pool=variable_pool,
app_id=self.app_id, app_id=self.app_id,
node_data_memory=self._node_data.memory, node_data_memory=self.node_data.memory,
model_instance=model_instance, model_instance=model_instance,
) )
query: str | None = None query: str | None = None
if self._node_data.memory: if self.node_data.memory:
query = self._node_data.memory.query_prompt_template query = self.node_data.memory.query_prompt_template
if not query and ( if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
): ):
@ -215,29 +213,29 @@ class LLMNode(Node[LLMNodeData]):
context=context, context=context,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
prompt_template=self._node_data.prompt_template, prompt_template=self.node_data.prompt_template,
memory_config=self._node_data.memory, memory_config=self.node_data.memory,
vision_enabled=self._node_data.vision.enabled, vision_enabled=self.node_data.vision.enabled,
vision_detail=self._node_data.vision.configs.detail, vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool, variable_pool=variable_pool,
jinja2_variables=self._node_data.prompt_config.jinja2_variables, jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
) )
# handle invoke result # handle invoke result
generator = LLMNode.invoke_llm( generator = LLMNode.invoke_llm(
node_data_model=self._node_data.model, node_data_model=self.node_data.model,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
stop=stop, stop=stop,
user_id=self.user_id, user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled, structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self._node_data.structured_output, structured_output=self.node_data.structured_output,
file_saver=self._llm_file_saver, file_saver=self._llm_file_saver,
file_outputs=self._file_outputs, file_outputs=self._file_outputs,
node_id=self._node_id, node_id=self._node_id,
node_type=self.node_type, node_type=self.node_type,
reasoning_format=self._node_data.reasoning_format, reasoning_format=self.node_data.reasoning_format,
) )
structured_output: LLMStructuredOutput | None = None structured_output: LLMStructuredOutput | None = None
@ -253,12 +251,12 @@ class LLMNode(Node[LLMNodeData]):
reasoning_content = event.reasoning_content or "" reasoning_content = event.reasoning_content or ""
# For downstream nodes, determine clean text based on reasoning_format # For downstream nodes, determine clean text based on reasoning_format
if self._node_data.reasoning_format == "tagged": if self.node_data.reasoning_format == "tagged":
# Keep <think> tags for backward compatibility # Keep <think> tags for backward compatibility
clean_text = result_text clean_text = result_text
else: else:
# Extract clean text from <think> tags # Extract clean text from <think> tags
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format) clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
# Process structured output if available from the event. # Process structured output if available from the event.
structured_output = ( structured_output = (
@ -1204,7 +1202,7 @@ class LLMNode(Node[LLMNodeData]):
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled return self.node_data.retry_config.retry_enabled
def _combine_message_content_with_role( def _combine_message_content_with_role(

View File

@ -11,8 +11,6 @@ class LoopEndNode(Node[LoopEndNodeData]):
node_type = NodeType.LOOP_END node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

View File

@ -46,7 +46,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
""" """
node_type = NodeType.LOOP node_type = NodeType.LOOP
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER execution_type = NodeExecutionType.CONTAINER
@classmethod @classmethod
@ -56,27 +55,27 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _run(self) -> Generator: def _run(self) -> Generator:
"""Run the node.""" """Run the node."""
# Get inputs # Get inputs
loop_count = self._node_data.loop_count loop_count = self.node_data.loop_count
break_conditions = self._node_data.break_conditions break_conditions = self.node_data.break_conditions
logical_operator = self._node_data.logical_operator logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count} inputs = {"loop_count": loop_count}
if not self._node_data.start_node_id: if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found") raise ValueError(f"field start_node_id in loop {self._node_id} not found")
root_node_id = self._node_data.start_node_id root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool # Initialize loop variables in the original variable pool
loop_variable_selectors = {} loop_variable_selectors = {}
if self._node_data.loop_variables: if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value) "variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
if isinstance(var.value, list) if isinstance(var.value, list)
else None, else None,
} }
for loop_variable in self._node_data.loop_variables: for loop_variable in self.node_data.loop_variables:
if loop_variable.value_type not in value_processor: if loop_variable.value_type not in value_processor:
raise ValueError( raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}" f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
@ -164,7 +163,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
yield LoopNextEvent( yield LoopNextEvent(
index=i + 1, index=i + 1,
pre_loop_output=self._node_data.outputs, pre_loop_output=self.node_data.outputs,
) )
self._accumulate_usage(loop_usage) self._accumulate_usage(loop_usage)
@ -172,7 +171,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
yield LoopSucceededEvent( yield LoopSucceededEvent(
start_at=start_at, start_at=start_at,
inputs=inputs, inputs=inputs,
outputs=self._node_data.outputs, outputs=self.node_data.outputs,
steps=loop_count, steps=loop_count,
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
@ -194,7 +193,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
}, },
outputs=self._node_data.outputs, outputs=self.node_data.outputs,
inputs=inputs, inputs=inputs,
llm_usage=loop_usage, llm_usage=loop_usage,
) )
@ -252,11 +251,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if isinstance(event, GraphRunFailedEvent): if isinstance(event, GraphRunFailedEvent):
raise Exception(event.error) raise Exception(event.error)
for loop_var in self._node_data.loop_variables or []: for loop_var in self.node_data.loop_variables or []:
key, sel = loop_var.label, [self._node_id, loop_var.label] key, sel = loop_var.label, [self._node_id, loop_var.label]
segment = self.graph_runtime_state.variable_pool.get(sel) segment = self.graph_runtime_state.variable_pool.get(sel)
self._node_data.outputs[key] = segment.value if segment else None self.node_data.outputs[key] = segment.value if segment else None
self._node_data.outputs["loop_round"] = current_index + 1 self.node_data.outputs["loop_round"] = current_index + 1
return reach_break_node return reach_break_node

View File

@ -11,8 +11,6 @@ class LoopStartNode(Node[LoopStartNodeData]):
node_type = NodeType.LOOP_START node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

View File

@ -90,8 +90,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_type = NodeType.PARAMETER_EXTRACTOR node_type = NodeType.PARAMETER_EXTRACTOR
_node_data: ParameterExtractorNodeData
_model_instance: ModelInstance | None = None _model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None _model_config: ModelConfigWithCredentialsEntity | None = None
@ -116,7 +114,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
""" """
Run the node. Run the node.
""" """
node_data = self._node_data node_data = self.node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query) variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else "" query = variable.text if variable else ""

View File

@ -47,8 +47,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = NodeType.QUESTION_CLASSIFIER node_type = NodeType.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH execution_type = NodeExecutionType.BRANCH
_node_data: QuestionClassifierNodeData
_file_outputs: list["File"] _file_outputs: list["File"]
_llm_file_saver: LLMFileSaver _llm_file_saver: LLMFileSaver
@ -82,7 +80,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
return "1" return "1"
def _run(self): def _run(self):
node_data = self._node_data node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
# extract variables # extract variables

View File

@ -9,8 +9,6 @@ class StartNode(Node[StartNodeData]):
node_type = NodeType.START node_type = NodeType.START
execution_type = NodeExecutionType.ROOT execution_type = NodeExecutionType.ROOT
_node_data: StartNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

View File

@ -14,8 +14,6 @@ MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]): class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM node_type = NodeType.TEMPLATE_TRANSFORM
_node_data: TemplateTransformNodeData
@classmethod @classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
""" """
@ -35,14 +33,14 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get variables # Get variables
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
for variable_selector in self._node_data.variables: for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None variables[variable_name] = value.to_object() if value else None
# Run code # Run code
try: try:
result = CodeExecutor.execute_workflow_code_template( result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
) )
except CodeExecutionError as e: except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

View File

@ -47,8 +47,6 @@ class ToolNode(Node[ToolNodeData]):
node_type = NodeType.TOOL node_type = NodeType.TOOL
_node_data: ToolNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -59,13 +57,11 @@ class ToolNode(Node[ToolNodeData]):
""" """
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
node_data = self._node_data
# fetch tool icon # fetch tool icon
tool_info = { tool_info = {
"provider_type": node_data.provider_type.value, "provider_type": self.node_data.provider_type.value,
"provider_id": node_data.provider_id, "provider_id": self.node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier, "plugin_unique_identifier": self.node_data.plugin_unique_identifier,
} }
# get tool runtime # get tool runtime
@ -77,10 +73,10 @@ class ToolNode(Node[ToolNodeData]):
# But for backward compatibility with historical data # But for backward compatibility with historical data
# this version field judgment is still preserved here. # this version field judgment is still preserved here.
variable_pool: VariablePool | None = None variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None: if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime( tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
) )
except ToolNodeError as e: except ToolNodeError as e:
yield StreamCompletedEvent( yield StreamCompletedEvent(
@ -99,12 +95,12 @@ class ToolNode(Node[ToolNodeData]):
parameters = self._generate_parameters( parameters = self._generate_parameters(
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data, node_data=self.node_data,
) )
parameters_for_log = self._generate_parameters( parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data, node_data=self.node_data,
for_log=True, for_log=True,
) )
# get conversation id # get conversation id
@ -149,7 +145,7 @@ class ToolNode(Node[ToolNodeData]):
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}", error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
error_type=type(e).__name__, error_type=type(e).__name__,
) )
) )
@ -159,7 +155,7 @@ class ToolNode(Node[ToolNodeData]):
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=e.to_user_friendly_error(plugin_name=node_data.provider_name), error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
error_type=type(e).__name__, error_type=type(e).__name__,
) )
) )
@ -495,4 +491,4 @@ class ToolNode(Node[ToolNodeData]):
@property @property
def retry(self) -> bool: def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled return self.node_data.retry_config.retry_enabled

View File

@ -43,9 +43,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
# Get trigger data passed when workflow was triggered # Get trigger data passed when workflow was triggered
metadata = { metadata = {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self._node_data.provider_id, "provider_id": self.node_data.provider_id,
"event_name": self._node_data.event_name, "event_name": self.node_data.event_name,
"plugin_unique_identifier": self._node_data.plugin_unique_identifier, "plugin_unique_identifier": self.node_data.plugin_unique_identifier,
}, },
} }
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)

View File

@ -84,7 +84,7 @@ class TriggerWebhookNode(Node[WebhookData]):
webhook_headers = webhook_data.get("headers", {}) webhook_headers = webhook_data.get("headers", {})
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()} webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
for header in self._node_data.headers: for header in self.node_data.headers:
header_name = header.name header_name = header.name
value = _get_normalized(webhook_headers, header_name) value = _get_normalized(webhook_headers, header_name)
if value is None: if value is None:
@ -93,20 +93,20 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[sanitized_name] = value outputs[sanitized_name] = value
# Extract configured query parameters # Extract configured query parameters
for param in self._node_data.params: for param in self.node_data.params:
param_name = param.name param_name = param.name
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name) outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
# Extract configured body parameters # Extract configured body parameters
for body_param in self._node_data.body: for body_param in self.node_data.body:
param_name = body_param.name param_name = body_param.name
param_type = body_param.type param_type = body_param.type
if self._node_data.content_type == ContentType.TEXT: if self.node_data.content_type == ContentType.TEXT:
# For text/plain, the entire body is a single string parameter # For text/plain, the entire body is a single string parameter
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", "")) outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue continue
elif self._node_data.content_type == ContentType.BINARY: elif self.node_data.content_type == ContentType.BINARY:
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"") outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
continue continue

View File

@ -10,8 +10,6 @@ from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorN
class VariableAggregatorNode(Node[VariableAggregatorNodeData]): class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
node_type = NodeType.VARIABLE_AGGREGATOR node_type = NodeType.VARIABLE_AGGREGATOR
_node_data: VariableAggregatorNodeData
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
@ -21,8 +19,8 @@ class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
outputs: dict[str, Segment | Mapping[str, Segment]] = {} outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {} inputs = {}
if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled: if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self._node_data.variables: for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None: if variable is not None:
outputs = {"output": variable} outputs = {"output": variable}
@ -30,7 +28,7 @@ class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
inputs = {".".join(selector[1:]): variable.to_object()} inputs = {".".join(selector[1:]): variable.to_object()}
break break
else: else:
for group in self._node_data.advanced_settings.groups: for group in self.node_data.advanced_settings.groups:
for selector in group.variables: for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)

View File

@ -25,8 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
_node_data: VariableAssignerData
def __init__( def __init__(
self, self,
id: str, id: str,
@ -71,21 +69,21 @@ class VariableAssignerNode(Node[VariableAssignerData]):
return mapping return mapping
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
assigned_variable_selector = self._node_data.assigned_variable_selector assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found") raise VariableOperatorNodeError("assigned variable not found")
match self._node_data.write_mode: match self.node_data.write_mode:
case WriteMode.OVER_WRITE: case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value: if not income_value:
raise VariableOperatorNodeError("input value not found") raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value}) updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND: case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value: if not income_value:
raise VariableOperatorNodeError("input value not found") raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value] updated_value = original_variable.value + [income_value.value]

View File

@ -53,8 +53,6 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]): class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER node_type = NodeType.VARIABLE_ASSIGNER
_node_data: VariableAssignerNodeData
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
""" """
Check if this Variable Assigner node blocks the output of specific variables. Check if this Variable Assigner node blocks the output of specific variables.
@ -62,7 +60,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
Returns True if this node updates any of the requested conversation variables. Returns True if this node updates any of the requested conversation variables.
""" """
# Check each item in this Variable Assigner node # Check each item in this Variable Assigner node
for item in self._node_data.items: for item in self.node_data.items:
# Convert the item's variable_selector to tuple for comparison # Convert the item's variable_selector to tuple for comparison
item_selector_tuple = tuple(item.variable_selector) item_selector_tuple = tuple(item.variable_selector)
@ -97,13 +95,13 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
return var_mapping return var_mapping
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
inputs = self._node_data.model_dump() inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {} process_data: dict[str, Any] = {}
# NOTE: This node has no outputs # NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = [] updated_variable_selectors: list[Sequence[str]] = []
try: try:
for item in self._node_data.items: for item in self.node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part # ==================== Validation Part