From e93040f63c06ece52e6340ab328ecd94b0793a2c Mon Sep 17 00:00:00 2001 From: Novice Date: Wed, 20 Aug 2025 16:27:50 +0800 Subject: [PATCH] fix: iteration and loop node single step run --- api/core/app/apps/advanced_chat/app_runner.py | 2 + api/core/app/apps/workflow_app_runner.py | 8 +- .../nodes/iteration/iteration_node.py | 49 +++--------- api/core/workflow/nodes/loop/loop_node.py | 79 +++++++++---------- 4 files changed, 55 insertions(+), 83 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 452dbbec01..d055832c03 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -89,6 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), graph_runtime_state=graph_runtime_state, ) + graph_runtime_state.variable_pool = variable_pool elif self.application_generate_entity.single_loop_run: # if only single loop run is requested graph_runtime_state = GraphRuntimeState( @@ -101,6 +102,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), graph_runtime_state=graph_runtime_state, ) + graph_runtime_state.variable_pool = variable_pool else: inputs = self.application_generate_entity.inputs query = self.application_generate_entity.query diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 5d9f64f8b6..cc7b0ac5dc 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -149,7 +149,9 @@ class WorkflowBasedAppRunner: node_configs = [ node for node in graph_config.get("nodes", []) - if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id + if node.get("id") == node_id + or node.get("data", {}).get("iteration_id", "") == node_id + or node.get("id") == f"{node_id}start" ] graph_config["nodes"] = node_configs @@ -264,7 +266,9 @@ class WorkflowBasedAppRunner: node_configs = [ node for node in graph_config.get("nodes", []) - if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id + if node.get("id") == node_id + or node.get("data", {}).get("loop_id", "") == node_id + or node.get("id") == f"{node_id}start" ] graph_config["nodes"] = node_configs diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index ab7b648af0..fddb4cfe58 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -208,43 +208,16 @@ class IterationNode(Node): variable_mapping: dict[str, Sequence[str]] = { f"{node_id}.input_selector": typed_node_data.iterator_selector, } + iteration_node_ids = set() - # init graph - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.graph import Graph - from core.workflow.nodes.node_factory import DifyNodeFactory - - # Create minimal GraphInitParams for static analysis - graph_init_params = GraphInitParams( - tenant_id="", - app_id="", - workflow_id="", - graph_config=graph_config, - user_id="", - user_from="", - invoke_from="", - call_depth=0, - ) - - # Create minimal GraphRuntimeState for static analysis - from core.workflow.entities import VariablePool - - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(), - start_at=0, - ) - - # Create node factory for static analysis - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - iteration_graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=typed_node_data.start_node_id, - ) - - if not iteration_graph: - raise IterationGraphNotFoundError("iteration graph not found") + # Find all nodes that belong to this loop + nodes = graph_config.get("nodes", []) + for node in nodes: + node_data = node.get("data", {}) + if node_data.get("iteration_id") == node_id: + in_iteration_node_id = node.get("id") + if in_iteration_node_id: + iteration_node_ids.add(in_iteration_node_id) # Get node configs from graph_config instead of non-existent node_id_config_mapping node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} @@ -280,9 +253,7 @@ class IterationNode(Node): variable_mapping.update(sub_node_variable_mapping) # remove variable out from iteration - variable_mapping = { - key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids - } + variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids} return variable_mapping diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index dffeee66f5..f13d4b1227 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,3 +1,4 @@ +import contextlib import json import logging from collections.abc import Callable, Generator, Mapping, Sequence @@ -126,11 +127,13 @@ class LoopNode(Node): try: reach_break_condition = False if break_conditions: - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) + with contextlib.suppress(ValueError): + _, _, reach_break_condition = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=break_conditions, + operator=logical_operator, + ) + if reach_break_condition: loop_count = 0 cost_tokens = 0 @@ -294,42 +297,11 @@ class LoopNode(Node): variable_mapping = {} - # init graph - from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool - from core.workflow.graph import Graph - from core.workflow.nodes.node_factory import DifyNodeFactory + # Extract loop node IDs statically from graph_config - # Create minimal GraphInitParams for static analysis - graph_init_params = GraphInitParams( - tenant_id="", - app_id="", - workflow_id="", - graph_config=graph_config, - user_id="", - user_from="", - invoke_from="", - call_depth=0, - ) + loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) - # Create minimal GraphRuntimeState for static analysis - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(), - start_at=0, - ) - - # Create node factory for static analysis - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - loop_graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=typed_node_data.start_node_id, - ) - - if not loop_graph: - raise ValueError("loop graph not found") - - # Get node configs from graph_config instead of non-existent node_id_config_mapping + # Get node configs from graph_config node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} for sub_node_id, sub_node_config in node_configs.items(): if sub_node_config.get("data", {}).get("loop_id") != node_id: @@ -370,12 +342,35 @@ class LoopNode(Node): variable_mapping[f"{node_id}.{loop_variable.label}"] = selector # remove variable out from loop - variable_mapping = { - key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids - } + variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids} return variable_mapping + @classmethod + def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: + """ + Extract node IDs that belong to a specific loop from graph configuration. + + This method statically analyzes the graph configuration to find all nodes + that are part of the specified loop, without creating actual node instances. + + :param graph_config: the complete graph configuration + :param loop_node_id: the ID of the loop node + :return: set of node IDs that belong to the loop + """ + loop_node_ids = set() + + # Find all nodes that belong to this loop + nodes = graph_config.get("nodes", []) + for node in nodes: + node_data = node.get("data", {}) + if node_data.get("loop_id") == loop_node_id: + node_id = node.get("id") + if node_id: + loop_node_ids.add(node_id) + + return loop_node_ids + @staticmethod def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: """Get the appropriate segment type for a constant value."""