From 8fb69429f9c9acdb06307004da03f0653d5e6171 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 11 Sep 2025 15:37:46 +0800
Subject: [PATCH] feat(graph_engine): support parallel mode in iteration node

Signed-off-by: -LAN-
---
 .../nodes/iteration/iteration_node.py         | 140 ++++++++++++++++--
 1 file changed, 127 insertions(+), 13 deletions(-)

diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index 10fe7473bb..1547f8ac3e 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -1,5 +1,6 @@
 import logging
 from collections.abc import Generator, Mapping, Sequence
+from concurrent.futures import Future, ThreadPoolExecutor, as_completed
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, NewType, cast
 
@@ -190,22 +191,135 @@ class IterationNode(Node):
         outputs: list[Any],
         iter_run_map: dict[str, float],
     ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
-        for index, item in enumerate(iterator_list_value):
-            iter_start_at = datetime.now(UTC).replace(tzinfo=None)
-            yield IterationNextEvent(index=index)
-
-            graph_engine = self._create_graph_engine(index, item)
-
-            # Run the iteration
-            yield from self._run_single_iter(
-                variable_pool=graph_engine.graph_runtime_state.variable_pool,
+        if self._node_data.is_parallel:
+            # Parallel mode execution
+            yield from self._execute_parallel_iterations(
+                iterator_list_value=iterator_list_value,
                 outputs=outputs,
-                graph_engine=graph_engine,
+                iter_run_map=iter_run_map,
             )
+        else:
+            # Sequential mode execution
+            for index, item in enumerate(iterator_list_value):
+                iter_start_at = datetime.now(UTC).replace(tzinfo=None)
+                yield IterationNextEvent(index=index)
 
-            # Update the total tokens from this iteration
-            self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
-            iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
+                graph_engine = self._create_graph_engine(index, item)
+
+                # Run the iteration
+                yield from self._run_single_iter(
+                    variable_pool=graph_engine.graph_runtime_state.variable_pool,
+                    outputs=outputs,
+                    graph_engine=graph_engine,
+                )
+
+                # Update the total tokens from this iteration
+                self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
+                iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
+
+    def _execute_parallel_iterations(
+        self,
+        iterator_list_value: Sequence[object],
+        outputs: list[Any],
+        iter_run_map: dict[str, float],
+    ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
+        """Run every iteration on a thread pool and replay its events as it finishes.
+
+        ``outputs`` is extended in place with one slot per item so results land at
+        their item's index even though iterations complete out of order, and
+        ``iter_run_map`` receives each iteration's elapsed seconds keyed by
+        ``str(index)``.
+
+        Raises:
+            IterationNodeError: an iteration raised while the error handle mode
+                is ``ErrorHandleMode.TERMINATED``.
+        """
+        if not iterator_list_value:
+            # ThreadPoolExecutor rejects max_workers=0, so an empty input is a no-op.
+            return
+
+        # Pre-fill with None so each result can be stored at its item's index.
+        outputs.extend([None] * len(iterator_list_value))
+
+        # Never start more workers than there are items to process.
+        max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            # Submit all iteration tasks
+            future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
+            for index, item in enumerate(iterator_list_value):
+                yield IterationNextEvent(index=index)
+                future = executor.submit(
+                    self._execute_single_iteration_parallel,
+                    index=index,
+                    item=item,
+                )
+                future_to_index[future] = index
+
+            # Process completed iterations as they finish
+            for future in as_completed(future_to_index):
+                index = future_to_index[future]
+                # Guard only the worker result: an error raised while the caller
+                # consumes the replayed events below is not an iteration failure.
+                try:
+                    iter_start_at, events, output_value, tokens_used = future.result()
+                except Exception as e:
+                    if self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
+                        # Drop the iterations that have not started, then surface
+                        # the failure with the original exception chained.
+                        for pending in future_to_index:
+                            pending.cancel()
+                        raise IterationNodeError(str(e)) from e
+                    # CONTINUE_ON_ERROR keeps this None; REMOVE_ABNORMAL_OUTPUT
+                    # filters it out after the pool has drained.
+                    outputs[index] = None
+                    continue
+
+                outputs[index] = output_value
+                yield from events
+
+                # Elapsed time is taken when the result is consumed here, so it
+                # includes any wait between the worker finishing and this point.
+                self.graph_runtime_state.total_tokens += tokens_used
+                iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
+
+        # NOTE(review): this drops every None, including a None produced by a
+        # successful iteration, not only the slots left by failures.
+        if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
+            outputs[:] = [output for output in outputs if output is not None]
+
+    def _execute_single_iteration_parallel(
+        self,
+        index: int,
+        item: Any,
+    ) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
+        """Run one iteration on a worker thread, buffering its events.
+
+        Events are collected rather than yielded so the caller can replay them
+        from its own generator.
+
+        Returns:
+            ``(start time, buffered events, first output or None, total tokens)``
+            where the token count is read from the iteration's own graph engine.
+        """
+        iter_start_at = datetime.now(UTC).replace(tzinfo=None)
+        outputs_temp: list[object] = []
+
+        graph_engine = self._create_graph_engine(index, item)
+
+        events: list[GraphNodeEventBase] = list(
+            self._run_single_iter(
+                variable_pool=graph_engine.graph_runtime_state.variable_pool,
+                outputs=outputs_temp,
+                graph_engine=graph_engine,
+            )
+        )
+
+        # NOTE(review): assumes _run_single_iter appends this iteration's result
+        # to the list it is given; an empty list is reported as None - confirm.
+        output_value = outputs_temp[0] if outputs_temp else None
+
+        return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
 
     def _handle_iteration_success(
         self,