diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index af8b7e4e17..919b135ec9 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -79,29 +79,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if not app_record: raise ValueError("App not found") - if self.application_generate_entity.single_iteration_run: - # if only single iteration run is requested - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + # Handle single iteration or single loop run + graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, - node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), - graph_runtime_state=graph_runtime_state, - ) - elif self.application_generate_entity.single_loop_run: - # if only single loop run is requested - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=self._workflow, - node_id=self.application_generate_entity.single_loop_run.node_id, - user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), - graph_runtime_state=graph_runtime_state, + single_iteration_run=self.application_generate_entity.single_iteration_run, + single_loop_run=self.application_generate_entity.single_loop_run, ) else: inputs = self.application_generate_entity.inputs diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 76627b876b..bd077c4cb8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -427,6 +427,9 @@ class PipelineGenerator(BaseAppGenerator): invoke_from=InvokeFrom.DEBUGGER, call_depth=0, workflow_execution_id=str(uuid.uuid4()), + single_iteration_run=RagPipelineGenerateEntity.SingleIterationRunEntity( + node_id=node_id, inputs=args["inputs"] + ), ) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) @@ -465,6 +468,7 @@ class PipelineGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, + context=contextvars.copy_context(), ) def single_loop_generate( @@ -559,6 +563,7 @@ class PipelineGenerator(BaseAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, variable_loader=var_loader, + context=contextvars.copy_context(), ) def _generate_worker( diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index ebb8b15163..145f629c4d 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -86,29 +86,12 @@ class PipelineRunner(WorkflowBasedAppRunner): db.session.close() # if only single iteration run is requested - if self.application_generate_entity.single_iteration_run: - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - # if only single iteration run is requested - graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + # Handle single iteration or single loop run + graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=workflow, - node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs, - graph_runtime_state=graph_runtime_state, - ) - elif self.application_generate_entity.single_loop_run: - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - # if only single loop run is requested - graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=workflow, - node_id=self.application_generate_entity.single_loop_run.node_id, - user_inputs=self.application_generate_entity.single_loop_run.inputs, - graph_runtime_state=graph_runtime_state, + single_iteration_run=self.application_generate_entity.single_iteration_run, + single_loop_run=self.application_generate_entity.single_loop_run, ) else: inputs = self.application_generate_entity.inputs diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index b009dc7715..943ae8ab4e 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -51,30 +51,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) - # if only single iteration run is requested - if self.application_generate_entity.single_iteration_run: - # if only single iteration run is requested - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + # if only single iteration or single loop run is requested + if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run: + graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution( workflow=self._workflow, - node_id=self.application_generate_entity.single_iteration_run.node_id, - user_inputs=self.application_generate_entity.single_iteration_run.inputs, - graph_runtime_state=graph_runtime_state, - ) - elif self.application_generate_entity.single_loop_run: - # if only single loop run is requested - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool.empty(), - start_at=time.time(), - ) - graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=self._workflow, - node_id=self.application_generate_entity.single_loop_run.node_id, - user_inputs=self.application_generate_entity.single_loop_run.inputs, - graph_runtime_state=graph_runtime_state, + single_iteration_run=self.application_generate_entity.single_iteration_run, + single_loop_run=self.application_generate_entity.single_loop_run, ) else: inputs = self.application_generate_entity.inputs diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 056e03fa14..564daba86d 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,3 +1,4 @@ +import time from collections.abc import Mapping from typing import Any, cast @@ -119,15 +120,81 @@ class WorkflowBasedAppRunner: return graph - def _get_graph_and_variable_pool_of_single_iteration( + def _prepare_single_node_execution( + self, + workflow: Workflow, + single_iteration_run: Any | None = None, + single_loop_run: Any | None = None, + ) -> tuple[Graph, VariablePool, GraphRuntimeState]: + """ + Prepare graph, variable pool, and runtime state for single node execution + (either single iteration or single loop). + + Args: + workflow: The workflow instance + single_iteration_run: SingleIterationRunEntity if running single iteration, None otherwise + single_loop_run: SingleLoopRunEntity if running single loop, None otherwise + + Returns: + A tuple containing (graph, variable_pool, graph_runtime_state) + + Raises: + ValueError: If neither single_iteration_run nor single_loop_run is specified + """ + # Create initial runtime state with variable pool containing environment variables + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool( + system_variables=SystemVariable.empty(), + user_inputs={}, + environment_variables=workflow.environment_variables, + ), + start_at=time.time(), + ) + + # Determine which type of single node execution and get graph/variable_pool + if single_iteration_run: + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=single_iteration_run.node_id, + user_inputs=dict(single_iteration_run.inputs), + graph_runtime_state=graph_runtime_state, + ) + elif single_loop_run: + graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( + workflow=workflow, + node_id=single_loop_run.node_id, + user_inputs=dict(single_loop_run.inputs), + graph_runtime_state=graph_runtime_state, + ) + else: + raise ValueError("Neither single_iteration_run nor single_loop_run is specified") + + # Return the graph, variable_pool, and the same graph_runtime_state used during graph creation + # This ensures all nodes in the graph reference the same GraphRuntimeState instance + return graph, variable_pool, graph_runtime_state + + def _get_graph_and_variable_pool_for_single_node_run( self, workflow: Workflow, node_id: str, - user_inputs: dict, + user_inputs: dict[str, Any], graph_runtime_state: GraphRuntimeState, + node_type_filter_key: str, # 'iteration_id' or 'loop_id' + node_type_label: str = "node", # 'iteration' or 'loop' for error messages ) -> tuple[Graph, VariablePool]: """ - Get variable pool of single iteration + Get graph and variable pool for single node execution (iteration or loop). + + Args: + workflow: The workflow instance + node_id: The node ID to execute + user_inputs: User inputs for the node + graph_runtime_state: The graph runtime state + node_type_filter_key: The key to filter nodes ('iteration_id' or 'loop_id') + node_type_label: Label for error messages ('iteration' or 'loop') + + Returns: + A tuple containing (graph, variable_pool) """ # fetch workflow graph graph_config = workflow.graph_dict @@ -145,18 +212,22 @@ class WorkflowBasedAppRunner: if not isinstance(graph_config.get("edges"), list): raise ValueError("edges in workflow graph must be a list") - # filter nodes only in iteration + # filter nodes only in the specified node type (iteration or loop) + main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None) + start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None node_configs = [ node for node in graph_config.get("nodes", []) - if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id + if node.get("id") == node_id + or node.get("data", {}).get(node_type_filter_key, "") == node_id + or (start_node_id and node.get("id") == start_node_id) ] graph_config["nodes"] = node_configs node_ids = [node.get("id") for node in node_configs] - # filter edges only in iteration + # filter edges only in the specified node type edge_configs = [ edge for edge in graph_config.get("edges", []) @@ -190,30 +261,26 @@ class WorkflowBasedAppRunner: raise ValueError("graph not found in workflow") # fetch node config from node id - iteration_node_config = None + target_node_config = None for node in node_configs: if node.get("id") == node_id: - iteration_node_config = node + target_node_config = node break - if not iteration_node_config: - raise ValueError("iteration node id not found in workflow graph") + if not target_node_config: + raise ValueError(f"{node_type_label} node id not found in workflow graph") # Get node class - node_type = NodeType(iteration_node_config.get("data", {}).get("type")) - node_version = iteration_node_config.get("data", {}).get("version", "1") + node_type = NodeType(target_node_config.get("data", {}).get("type")) + node_version = target_node_config.get("data", {}).get("version", "1") node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] - # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.empty(), - user_inputs={}, - environment_variables=workflow.environment_variables, - ) + # Use the variable pool from graph_runtime_state instead of creating a new one + variable_pool = graph_runtime_state.variable_pool try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, config=iteration_node_config + graph_config=workflow.graph_dict, config=target_node_config ) except NotImplementedError: variable_mapping = {} @@ -234,120 +301,44 @@ class WorkflowBasedAppRunner: return graph, variable_pool + def _get_graph_and_variable_pool_of_single_iteration( + self, + workflow: Workflow, + node_id: str, + user_inputs: dict[str, Any], + graph_runtime_state: GraphRuntimeState, + ) -> tuple[Graph, VariablePool]: + """ + Get variable pool of single iteration + """ + return self._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id=node_id, + user_inputs=user_inputs, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="iteration_id", + node_type_label="iteration", + ) + def _get_graph_and_variable_pool_of_single_loop( self, workflow: Workflow, node_id: str, - user_inputs: dict, + user_inputs: dict[str, Any], graph_runtime_state: GraphRuntimeState, ) -> tuple[Graph, VariablePool]: """ Get variable pool of single loop """ - # fetch workflow graph - graph_config = workflow.graph_dict - if not graph_config: - raise ValueError("workflow graph not found") - - graph_config = cast(dict[str, Any], graph_config) - - if "nodes" not in graph_config or "edges" not in graph_config: - raise ValueError("nodes or edges not found in workflow graph") - - if not isinstance(graph_config.get("nodes"), list): - raise ValueError("nodes in workflow graph must be a list") - - if not isinstance(graph_config.get("edges"), list): - raise ValueError("edges in workflow graph must be a list") - - # filter nodes only in loop - node_configs = [ - node - for node in graph_config.get("nodes", []) - if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id - ] - - graph_config["nodes"] = node_configs - - node_ids = [node.get("id") for node in node_configs] - - # filter edges only in loop - edge_configs = [ - edge - for edge in graph_config.get("edges", []) - if (edge.get("source") is None or edge.get("source") in node_ids) - and (edge.get("target") is None or edge.get("target") in node_ids) - ] - - graph_config["edges"] = edge_configs - - # Create required parameters for Graph.init - graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, - workflow_id=workflow.id, - graph_config=graph_config, - user_id="", - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.SERVICE_API.value, - call_depth=0, - ) - - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + return self._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id=node_id, + user_inputs=user_inputs, graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", ) - # init graph - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=node_id) - - if not graph: - raise ValueError("graph not found in workflow") - - # fetch node config from node id - loop_node_config = None - for node in node_configs: - if node.get("id") == node_id: - loop_node_config = node - break - - if not loop_node_config: - raise ValueError("loop node id not found in workflow graph") - - # Get node class - node_type = NodeType(loop_node_config.get("data", {}).get("type")) - node_version = loop_node_config.get("data", {}).get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] - - # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.empty(), - user_inputs={}, - environment_variables=workflow.environment_variables, - ) - - try: - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, config=loop_node_config - ) - except NotImplementedError: - variable_mapping = {} - load_into_variable_pool( - self._variable_loader, - variable_pool=variable_pool, - variable_mapping=variable_mapping, - user_inputs=user_inputs, - ) - - WorkflowEntry.mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - ) - - return graph, variable_pool - def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): """ Handle event diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 5340a5b6ce..6e57b17d5c 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -372,43 +372,16 @@ class IterationNode(Node): variable_mapping: dict[str, Sequence[str]] = { f"{node_id}.input_selector": typed_node_data.iterator_selector, } + iteration_node_ids = set() - # init graph - from core.workflow.entities import GraphInitParams, GraphRuntimeState - from core.workflow.graph import Graph - from core.workflow.nodes.node_factory import DifyNodeFactory - - # Create minimal GraphInitParams for static analysis - graph_init_params = GraphInitParams( - tenant_id="", - app_id="", - workflow_id="", - graph_config=graph_config, - user_id="", - user_from="", - invoke_from="", - call_depth=0, - ) - - # Create minimal GraphRuntimeState for static analysis - from core.workflow.entities import VariablePool - - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(), - start_at=0, - ) - - # Create node factory for static analysis - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - iteration_graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=typed_node_data.start_node_id, - ) - - if not iteration_graph: - raise IterationGraphNotFoundError("iteration graph not found") + # Find all nodes that belong to this loop + nodes = graph_config.get("nodes", []) + for node in nodes: + node_data = node.get("data", {}) + if node_data.get("iteration_id") == node_id: + in_iteration_node_id = node.get("id") + if in_iteration_node_id: + iteration_node_ids.add(in_iteration_node_id) # Get node configs from graph_config instead of non-existent node_id_config_mapping node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} @@ -444,9 +417,7 @@ class IterationNode(Node): variable_mapping.update(sub_node_variable_mapping) # remove variable out from iteration - variable_mapping = { - key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids - } + variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids} return variable_mapping diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 2b988ad944..790975d556 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,3 +1,4 @@ +import contextlib import json import logging from collections.abc import Callable, Generator, Mapping, Sequence @@ -127,11 +128,13 @@ class LoopNode(Node): try: reach_break_condition = False if break_conditions: - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) + with contextlib.suppress(ValueError): + _, _, reach_break_condition = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=break_conditions, + operator=logical_operator, + ) + if reach_break_condition: loop_count = 0 cost_tokens = 0 @@ -295,42 +298,11 @@ class LoopNode(Node): variable_mapping = {} - # init graph - from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool - from core.workflow.graph import Graph - from core.workflow.nodes.node_factory import DifyNodeFactory + # Extract loop node IDs statically from graph_config - # Create minimal GraphInitParams for static analysis - graph_init_params = GraphInitParams( - tenant_id="", - app_id="", - workflow_id="", - graph_config=graph_config, - user_id="", - user_from="", - invoke_from="", - call_depth=0, - ) + loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) - # Create minimal GraphRuntimeState for static analysis - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(), - start_at=0, - ) - - # Create node factory for static analysis - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - loop_graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=typed_node_data.start_node_id, - ) - - if not loop_graph: - raise ValueError("loop graph not found") - - # Get node configs from graph_config instead of non-existent node_id_config_mapping + # Get node configs from graph_config node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} for sub_node_id, sub_node_config in node_configs.items(): if sub_node_config.get("data", {}).get("loop_id") != node_id: @@ -371,12 +343,35 @@ class LoopNode(Node): variable_mapping[f"{node_id}.{loop_variable.label}"] = selector # remove variable out from loop - variable_mapping = { - key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids - } + variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids} return variable_mapping + @classmethod + def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: + """ + Extract node IDs that belong to a specific loop from graph configuration. + + This method statically analyzes the graph configuration to find all nodes + that are part of the specified loop, without creating actual node instances. + + :param graph_config: the complete graph configuration + :param loop_node_id: the ID of the loop node + :return: set of node IDs that belong to the loop + """ + loop_node_ids = set() + + # Find all nodes that belong to this loop + nodes = graph_config.get("nodes", []) + for node in nodes: + node_data = node.get("data", {}) + if node_data.get("loop_id") == loop_node_id: + node_id = node.get("id") + if node_id: + loop_node_ids.add(node_id) + + return loop_node_ids + @staticmethod def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: """Get the appropriate segment type for a constant value."""