Refactor workflow nodes to use generic node_data (#28782)

This commit is contained in:
-LAN- 2025-11-27 20:46:56 +08:00 committed by GitHub
parent 002d8769b0
commit 8b761319f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 121 additions and 170 deletions

View File

@ -70,7 +70,6 @@ class AgentNode(Node[AgentNodeData]):
"""
node_type = NodeType.AGENT
_node_data: AgentNodeData
@classmethod
def version(cls) -> str:
@ -82,8 +81,8 @@ class AgentNode(Node[AgentNodeData]):
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
agent_strategy_name=self._node_data.agent_strategy_name,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
@ -101,13 +100,13 @@ class AgentNode(Node[AgentNodeData]):
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
node_data=self.node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
node_data=self.node_data,
for_log=True,
strategy=strategy,
)
@ -140,7 +139,7 @@ class AgentNode(Node[AgentNodeData]):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": self._node_data.agent_strategy_name,
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@ -387,7 +386,7 @@ class AgentNode(Node[AgentNodeData]):
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:

View File

@ -14,14 +14,12 @@ class AnswerNode(Node[AnswerNodeData]):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE
_node_data: AnswerNodeData
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer)
files = self._extract_files_from_segments(segments.value)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -71,4 +69,4 @@ class AnswerNode(Node[AnswerNodeData]):
Returns:
Template instance for this Answer node
"""
return Template.from_answer_template(self._node_data.answer)
return Template.from_answer_template(self.node_data.answer)

View File

@ -24,8 +24,6 @@ from .exc import (
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_node_data: CodeNodeData
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
@ -48,12 +46,12 @@ class CodeNode(Node[CodeNodeData]):
def _run(self) -> NodeRunResult:
# Get code language
code_language = self._node_data.code_language
code = self._node_data.code
code_language = self.node_data.code_language
code = self.node_data.code
# Get variables
variables = {}
for variable_selector in self._node_data.variables:
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
@ -69,7 +67,7 @@ class CodeNode(Node[CodeNodeData]):
)
# Transform result
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -406,7 +404,7 @@ class CodeNode(Node[CodeNodeData]):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
return self.node_data.retry_config.retry_enabled
@staticmethod
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:

View File

@ -42,7 +42,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
Datasource Node
"""
_node_data: DatasourceNodeData
node_type = NodeType.DATASOURCE
execution_type = NodeExecutionType.ROOT
@ -51,7 +50,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
Run the datasource node
"""
node_data = self._node_data
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segement:

View File

@ -43,14 +43,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
node_type = NodeType.DOCUMENT_EXTRACTOR
_node_data: DocumentExtractorNodeData
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
variable_selector = self._node_data.variable_selector
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None:

View File

@ -9,8 +9,6 @@ class EndNode(Node[EndNodeData]):
node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE
_node_data: EndNodeData
@classmethod
def version(cls) -> str:
return "1"
@ -22,7 +20,7 @@ class EndNode(Node[EndNodeData]):
This method runs after streaming is complete (if streaming was enabled).
It collects all output variables and returns them.
"""
output_variables = self._node_data.outputs
output_variables = self.node_data.outputs
outputs = {}
for variable_selector in output_variables:
@ -44,6 +42,6 @@ class EndNode(Node[EndNodeData]):
Template instance for this End node
"""
outputs_config = [
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
{"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs
]
return Template.from_end_outputs(outputs_config)

View File

@ -34,8 +34,6 @@ logger = logging.getLogger(__name__)
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
_node_data: HttpRequestNodeData
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@ -69,8 +67,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
process_data = {}
try:
http_executor = Executor(
node_data=self._node_data,
timeout=self._get_request_timeout(self._node_data),
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
@ -225,4 +223,4 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
return self.node_data.retry_config.retry_enabled

View File

@ -25,8 +25,6 @@ class HumanInputNode(Node[HumanInputNodeData]):
"handle",
)
_node_data: HumanInputNodeData
@classmethod
def version(cls) -> str:
return "1"
@ -49,12 +47,12 @@ class HumanInputNode(Node[HumanInputNodeData]):
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""
if not self._node_data.required_variables:
if not self.node_data.required_variables:
return False
variable_pool = self.graph_runtime_state.variable_pool
for selector_str in self._node_data.required_variables:
for selector_str in self.node_data.required_variables:
parts = selector_str.split(".")
if len(parts) != 2:
return False
@ -74,7 +72,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
if handle:
return handle
default_values = self._node_data.default_value_dict
default_values = self.node_data.default_value_dict
for key in self._BRANCH_SELECTION_KEYS:
handle = self._normalize_branch_value(default_values.get(key))
if handle:

View File

@ -16,8 +16,6 @@ class IfElseNode(Node[IfElseNodeData]):
node_type = NodeType.IF_ELSE
execution_type = NodeExecutionType.BRANCH
_node_data: IfElseNodeData
@classmethod
def version(cls) -> str:
return "1"
@ -37,8 +35,8 @@ class IfElseNode(Node[IfElseNodeData]):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if self._node_data.cases:
for case in self._node_data.cases:
if self.node_data.cases:
for case in self.node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
@ -64,8 +62,8 @@ class IfElseNode(Node[IfElseNodeData]):
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self._node_data.conditions or [],
operator=self._node_data.logical_operator or "and",
conditions=self.node_data.conditions or [],
operator=self.node_data.logical_operator or "and",
)
selected_case_id = "true" if final_result else "false"

View File

@ -65,7 +65,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
node_type = NodeType.ITERATION
execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -136,10 +135,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
)
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@ -174,7 +173,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return cast(list[object], iterator_list_value)
def _validate_start_node(self) -> None:
if not self._node_data.start_node_id:
if not self.node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
def _execute_iterations(
@ -184,7 +183,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
if self.node_data.is_parallel:
# Parallel mode execution
yield from self._execute_parallel_iterations(
iterator_list_value=iterator_list_value,
@ -231,7 +230,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
outputs.extend([None] * len(iterator_list_value))
# Determine the number of parallel workers
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
max_workers = min(self.node_data.parallel_nums, len(iterator_list_value))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
@ -287,7 +286,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
except Exception as e:
# Handle errors based on error_handle_mode
match self._node_data.error_handle_mode:
match self.node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
# Cancel remaining futures and re-raise
for f in future_to_index:
@ -300,7 +299,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
outputs[index] = None # Will be filtered later
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs[:] = [output for output in outputs if output is not None]
def _execute_single_iteration_parallel(
@ -389,7 +388,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
If flatten_output is True (default), flattens the list if all elements are lists.
"""
# If flatten_output is disabled, return outputs as-is
if not self._node_data.flatten_output:
if not self.node_data.flatten_output:
return outputs
if not outputs:
@ -569,14 +568,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
yield event
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
result = variable_pool.get(self._node_data.output_selector)
result = variable_pool.get(self.node_data.output_selector)
if result is None:
outputs.append(None)
else:
outputs.append(result.to_object())
return
elif isinstance(event, GraphRunFailedEvent):
match self._node_data.error_handle_mode:
match self.node_data.error_handle_mode:
case ErrorHandleMode.TERMINATED:
raise IterationNodeError(event.error)
case ErrorHandleMode.CONTINUE_ON_ERROR:
@ -627,7 +626,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
# Initialize the iteration graph with the new node factory
iteration_graph = Graph.init(
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
)
if not iteration_graph:

View File

@ -11,8 +11,6 @@ class IterationStartNode(Node[IterationStartNodeData]):
node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData
@classmethod
def version(cls) -> str:
return "1"

View File

@ -35,12 +35,11 @@ default_retrieval_model = {
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
_node_data: KnowledgeIndexNodeData
node_type = NodeType.KNOWLEDGE_INDEX
execution_type = NodeExecutionType.RESPONSE
def _run(self) -> NodeRunResult: # type: ignore
node_data = self._node_data
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
if not dataset_id:

View File

@ -83,8 +83,6 @@ default_retrieval_model = {
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@ -122,7 +120,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def _run(self) -> NodeRunResult:
# extract variables
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -163,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -536,7 +534,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,

View File

@ -37,8 +37,6 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
class ListOperatorNode(Node[ListOperatorNodeData]):
node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData
@classmethod
def version(cls) -> str:
return "1"
@ -48,9 +46,9 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
process_data: dict[str, Sequence[object]] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
if variable is None:
error_message = f"Variable not found for selector: {self._node_data.variable}"
error_message = f"Variable not found for selector: {self.node_data.variable}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -69,7 +67,7 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
outputs=outputs,
)
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -83,19 +81,19 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
try:
# Filter
if self._node_data.filter_by.enabled:
if self.node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Extract
if self._node_data.extract_by.enabled:
if self.node_data.extract_by.enabled:
variable = self._extract_slice(variable)
# Order
if self._node_data.order_by.enabled:
if self.node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
if self._node_data.limit.enabled:
if self.node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
@ -121,7 +119,7 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self._node_data.filter_by.conditions:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@ -160,22 +158,22 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC)
variable = variable.model_copy(update={"value": result})
else:
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
result = variable.value[: self._node_data.limit.size]
result = variable.value[: self.node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
if value > len(variable.value):

View File

@ -102,8 +102,6 @@ logger = logging.getLogger(__name__)
class LLMNode(Node[LLMNodeData]):
node_type = NodeType.LLM
_node_data: LLMNodeData
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
@ -154,13 +152,13 @@ class LLMNode(Node[LLMNodeData]):
try:
# init messages template
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self._node_data)
inputs = self._fetch_inputs(node_data=self.node_data)
# fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
# merge inputs
inputs.update(jinja_inputs)
@ -169,9 +167,9 @@ class LLMNode(Node[LLMNodeData]):
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=self._node_data.vision.configs.variable_selector,
selector=self.node_data.vision.configs.variable_selector,
)
if self._node_data.vision.enabled
if self.node_data.vision.enabled
else []
)
@ -179,7 +177,7 @@ class LLMNode(Node[LLMNodeData]):
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
generator = self._fetch_context(node_data=self._node_data)
generator = self._fetch_context(node_data=self.node_data)
context = None
for event in generator:
context = event.context
@ -189,7 +187,7 @@ class LLMNode(Node[LLMNodeData]):
# fetch model config
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self._node_data.model,
node_data_model=self.node_data.model,
tenant_id=self.tenant_id,
)
@ -197,13 +195,13 @@ class LLMNode(Node[LLMNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self._node_data.memory,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
)
query: str | None = None
if self._node_data.memory:
query = self._node_data.memory.query_prompt_template
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
@ -215,29 +213,29 @@ class LLMNode(Node[LLMNodeData]):
context=context,
memory=memory,
model_config=model_config,
prompt_template=self._node_data.prompt_template,
memory_config=self._node_data.memory,
vision_enabled=self._node_data.vision.enabled,
vision_detail=self._node_data.vision.configs.detail,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self._node_data.prompt_config.jinja2_variables,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
)
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=self._node_data.model,
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self.node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
reasoning_format=self._node_data.reasoning_format,
reasoning_format=self.node_data.reasoning_format,
)
structured_output: LLMStructuredOutput | None = None
@ -253,12 +251,12 @@ class LLMNode(Node[LLMNodeData]):
reasoning_content = event.reasoning_content or ""
# For downstream nodes, determine clean text based on reasoning_format
if self._node_data.reasoning_format == "tagged":
if self.node_data.reasoning_format == "tagged":
# Keep <think> tags for backward compatibility
clean_text = result_text
else:
# Extract clean text from <think> tags
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
# Process structured output if available from the event.
structured_output = (
@ -1204,7 +1202,7 @@ class LLMNode(Node[LLMNodeData]):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
return self.node_data.retry_config.retry_enabled
def _combine_message_content_with_role(

View File

@ -11,8 +11,6 @@ class LoopEndNode(Node[LoopEndNodeData]):
node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData
@classmethod
def version(cls) -> str:
return "1"

View File

@ -46,7 +46,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
"""
node_type = NodeType.LOOP
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER
@classmethod
@ -56,27 +55,27 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _run(self) -> Generator:
"""Run the node."""
# Get inputs
loop_count = self._node_data.loop_count
break_conditions = self._node_data.break_conditions
logical_operator = self._node_data.logical_operator
loop_count = self.node_data.loop_count
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count}
if not self._node_data.start_node_id:
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
root_node_id = self._node_data.start_node_id
root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
if self._node_data.loop_variables:
if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
if isinstance(var.value, list)
else None,
}
for loop_variable in self._node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
if loop_variable.value_type not in value_processor:
raise ValueError(
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
@ -164,7 +163,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
yield LoopNextEvent(
index=i + 1,
pre_loop_output=self._node_data.outputs,
pre_loop_output=self.node_data.outputs,
)
self._accumulate_usage(loop_usage)
@ -172,7 +171,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
yield LoopSucceededEvent(
start_at=start_at,
inputs=inputs,
outputs=self._node_data.outputs,
outputs=self.node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
@ -194,7 +193,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self._node_data.outputs,
outputs=self.node_data.outputs,
inputs=inputs,
llm_usage=loop_usage,
)
@ -252,11 +251,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
if isinstance(event, GraphRunFailedEvent):
raise Exception(event.error)
for loop_var in self._node_data.loop_variables or []:
for loop_var in self.node_data.loop_variables or []:
key, sel = loop_var.label, [self._node_id, loop_var.label]
segment = self.graph_runtime_state.variable_pool.get(sel)
self._node_data.outputs[key] = segment.value if segment else None
self._node_data.outputs["loop_round"] = current_index + 1
self.node_data.outputs[key] = segment.value if segment else None
self.node_data.outputs["loop_round"] = current_index + 1
return reach_break_node

View File

@ -11,8 +11,6 @@ class LoopStartNode(Node[LoopStartNodeData]):
node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData
@classmethod
def version(cls) -> str:
return "1"

View File

@ -90,8 +90,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_type = NodeType.PARAMETER_EXTRACTOR
_node_data: ParameterExtractorNodeData
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
@ -116,7 +114,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Run the node.
"""
node_data = self._node_data
node_data = self.node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""

View File

@ -47,8 +47,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = NodeType.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
_node_data: QuestionClassifierNodeData
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
@ -82,7 +80,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
return "1"
def _run(self):
node_data = self._node_data
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
# extract variables

View File

@ -9,8 +9,6 @@ class StartNode(Node[StartNodeData]):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT
_node_data: StartNodeData
@classmethod
def version(cls) -> str:
return "1"

View File

@ -14,8 +14,6 @@ MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_node_data: TemplateTransformNodeData
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
@ -35,14 +33,14 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
def _run(self) -> NodeRunResult:
# Get variables
variables: dict[str, Any] = {}
for variable_selector in self._node_data.variables:
for variable_selector in self.node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

View File

@ -47,8 +47,6 @@ class ToolNode(Node[ToolNodeData]):
node_type = NodeType.TOOL
_node_data: ToolNodeData
@classmethod
def version(cls) -> str:
return "1"
@ -59,13 +57,11 @@ class ToolNode(Node[ToolNodeData]):
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
node_data = self._node_data
# fetch tool icon
tool_info = {
"provider_type": node_data.provider_type.value,
"provider_id": node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier,
"provider_type": self.node_data.provider_type.value,
"provider_id": self.node_data.provider_id,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
}
# get tool runtime
@ -77,10 +73,10 @@ class ToolNode(Node[ToolNodeData]):
# But for backward compatibility with historical data
# this version field judgment is still preserved here.
variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield StreamCompletedEvent(
@ -99,12 +95,12 @@ class ToolNode(Node[ToolNodeData]):
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
node_data=self.node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
node_data=self.node_data,
for_log=True,
)
# get conversation id
@ -149,7 +145,7 @@ class ToolNode(Node[ToolNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
error_type=type(e).__name__,
)
)
@ -159,7 +155,7 @@ class ToolNode(Node[ToolNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=e.to_user_friendly_error(plugin_name=node_data.provider_name),
error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
error_type=type(e).__name__,
)
)
@ -495,4 +491,4 @@ class ToolNode(Node[ToolNodeData]):
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
return self.node_data.retry_config.retry_enabled

View File

@ -43,9 +43,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
# Get trigger data passed when workflow was triggered
metadata = {
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
"provider_id": self._node_data.provider_id,
"event_name": self._node_data.event_name,
"plugin_unique_identifier": self._node_data.plugin_unique_identifier,
"provider_id": self.node_data.provider_id,
"event_name": self.node_data.event_name,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
},
}
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)

View File

@ -84,7 +84,7 @@ class TriggerWebhookNode(Node[WebhookData]):
webhook_headers = webhook_data.get("headers", {})
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
for header in self._node_data.headers:
for header in self.node_data.headers:
header_name = header.name
value = _get_normalized(webhook_headers, header_name)
if value is None:
@ -93,20 +93,20 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[sanitized_name] = value
# Extract configured query parameters
for param in self._node_data.params:
for param in self.node_data.params:
param_name = param.name
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
# Extract configured body parameters
for body_param in self._node_data.body:
for body_param in self.node_data.body:
param_name = body_param.name
param_type = body_param.type
if self._node_data.content_type == ContentType.TEXT:
if self.node_data.content_type == ContentType.TEXT:
# For text/plain, the entire body is a single string parameter
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
continue
elif self._node_data.content_type == ContentType.BINARY:
elif self.node_data.content_type == ContentType.BINARY:
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
continue

View File

@ -10,8 +10,6 @@ from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorN
class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
node_type = NodeType.VARIABLE_AGGREGATOR
_node_data: VariableAggregatorNodeData
@classmethod
def version(cls) -> str:
return "1"
@ -21,8 +19,8 @@ class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}
if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
for selector in self._node_data.variables:
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
@ -30,7 +28,7 @@ class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
inputs = {".".join(selector[1:]): variable.to_object()}
break
else:
for group in self._node_data.advanced_settings.groups:
for group in self.node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)

View File

@ -25,8 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
_node_data: VariableAssignerData
def __init__(
self,
id: str,
@ -71,21 +69,21 @@ class VariableAssignerNode(Node[VariableAssignerData]):
return mapping
def _run(self) -> NodeRunResult:
assigned_variable_selector = self._node_data.assigned_variable_selector
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
match self._node_data.write_mode:
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]

View File

@ -53,8 +53,6 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
_node_data: VariableAssignerNodeData
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
@ -62,7 +60,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
Returns True if this node updates any of the requested conversation variables.
"""
# Check each item in this Variable Assigner node
for item in self._node_data.items:
for item in self.node_data.items:
# Convert the item's variable_selector to tuple for comparison
item_selector_tuple = tuple(item.variable_selector)
@ -97,13 +95,13 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
return var_mapping
def _run(self) -> NodeRunResult:
inputs = self._node_data.model_dump()
inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = []
try:
for item in self._node_data.items:
for item in self.node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part