mirror of https://github.com/langgenius/dify.git
feat(graph_engine): support parallel mode in iteration node
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
85064bd8cf
commit
8fb69429f9
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
|
|
@ -190,22 +191,115 @@ class IterationNode(Node):
|
|||
outputs: list[Any],
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
yield IterationNextEvent(index=index)
|
||||
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
# Run the iteration
|
||||
yield from self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
if self._node_data.is_parallel:
|
||||
# Parallel mode execution
|
||||
yield from self._execute_parallel_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
outputs=outputs,
|
||||
graph_engine=graph_engine,
|
||||
iter_run_map=iter_run_map,
|
||||
)
|
||||
else:
|
||||
# Sequential mode execution
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
yield IterationNextEvent(index=index)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
# Run the iteration
|
||||
yield from self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs,
|
||||
graph_engine=graph_engine,
|
||||
)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
def _execute_parallel_iterations(
|
||||
self,
|
||||
iterator_list_value: Sequence[object],
|
||||
outputs: list[Any],
|
||||
iter_run_map: dict[str, float],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
# Initialize outputs list with None values to maintain order
|
||||
outputs.extend([None] * len(iterator_list_value))
|
||||
|
||||
# Determine the number of parallel workers
|
||||
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
future_to_index: dict[Future[tuple[datetime, list[GraphNodeEventBase], object | None, int]], int] = {}
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
yield IterationNextEvent(index=index)
|
||||
future = executor.submit(
|
||||
self._execute_single_iteration_parallel,
|
||||
index=index,
|
||||
item=item,
|
||||
)
|
||||
future_to_index[future] = index
|
||||
|
||||
# Process completed iterations as they finish
|
||||
for future in as_completed(future_to_index):
|
||||
index = future_to_index[future]
|
||||
try:
|
||||
result = future.result()
|
||||
iter_start_at, events, output_value, tokens_used = result
|
||||
|
||||
# Update outputs at the correct index
|
||||
outputs[index] = output_value
|
||||
|
||||
# Yield all events from this iteration
|
||||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
self.graph_runtime_state.total_tokens += tokens_used
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
# Cancel remaining futures and re-raise
|
||||
for f in future_to_index:
|
||||
if f != future:
|
||||
f.cancel()
|
||||
raise IterationNodeError(str(e))
|
||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
outputs[index] = None
|
||||
case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[index] = None # Will be filtered later
|
||||
|
||||
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
||||
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[:] = [output for output in outputs if output is not None]
|
||||
|
||||
def _execute_single_iteration_parallel(
|
||||
self,
|
||||
index: int,
|
||||
item: Any,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
|
||||
graph_engine = self._create_graph_engine(index, item)
|
||||
|
||||
# Collect events instead of yielding them directly
|
||||
for event in self._run_single_iter(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool,
|
||||
outputs=outputs_temp,
|
||||
graph_engine=graph_engine,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Get the output value from the temporary outputs list
|
||||
output_value = outputs_temp[0] if outputs_temp else None
|
||||
|
||||
return iter_start_at, events, output_value, graph_engine.graph_runtime_state.total_tokens
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Reference in New Issue