This commit is contained in:
takatost 2024-07-31 02:27:23 +08:00
parent c9bb366e1a
commit 8d27ec364f
2 changed files with 1 additions and 56 deletions

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseNodeData

View File

@ -273,61 +273,6 @@ class IterationNode(BaseNode):
variable_pool.remove([self.node_id, 'index'])
variable_pool.remove([self.node_id, 'item'])
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
"""
Set current iteration variable.
:variable_pool: variable pool
"""
node_data = cast(IterationNodeData, self.node_data)
variable_pool.add((self.node_id, 'index'), state.index)
# get the iterator value
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return
if state.index < len(iterator):
variable_pool.add((self.node_id, 'item'), iterator[state.index])
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
"""
Move to next iteration.
:param variable_pool: variable pool
"""
state.index += 1
self._set_current_iteration_variable(variable_pool, state)
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
"""
Check if iteration limit is reached.
:return: True if iteration limit is reached, False otherwise
"""
node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return True
return state.index >= len(iterator)
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
"""
Resolve current output.
:param variable_pool: variable pool
"""
output_selector = cast(IterationNodeData, self.node_data).output_selector
output = variable_pool.get_any(output_selector)
# clear the output for this iteration
variable_pool.remove([self.node_id] + output_selector[1:])
state.current_output = output
if output is not None:
# NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration).
if isinstance(output, list):
state.outputs.extend(output)
else:
state.outputs.append(output)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
"""