From e0e82fbfaa8116f188312171196a703ade4aa96c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:15:36 +0800 Subject: [PATCH] refactor: extract _run method into smaller focused methods in IterationNode - Extract iterator variable retrieval and validation logic - Separate empty iteration handling - Create dedicated methods for iteration execution and result handling - Improve type hints and use modern Python syntax - Enhance code readability and maintainability --- .../nodes/iteration/iteration_node.py | 229 ++++++++++++------ 1 file changed, 149 insertions(+), 80 deletions(-) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index f15730d105..6aaa432ca7 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,7 +1,9 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, NewType, cast + +from typing_extensions import TypeIs from core.variables import IntegerVariable, NoneSegment from core.variables.segments import ArrayAnySegment, ArraySegment @@ -23,6 +25,7 @@ from core.workflow.node_events import ( IterationNextEvent, IterationStartedEvent, IterationSucceededEvent, + NodeEventBase, NodeRunResult, StreamCompletedEvent, ) @@ -45,6 +48,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) class IterationNode(Node): """ @@ -58,7 +62,7 @@ class IterationNode(Node): def init_node_data(self, data: Mapping[str, Any]): self._node_data = IterationNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -67,7 +71,7 @@ class IterationNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -77,7 +81,7 @@ class IterationNode(Node): return self._node_data @classmethod - def get_default_config(cls, filters: Optional[dict] = None): + def get_default_config(cls, filters: dict[str, object] | None = None): return { "type": "iteration", "config": { @@ -91,40 +95,17 @@ class IterationNode(Node): def version(cls) -> str: return "1" - def _run(self) -> Generator: - variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) + def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # pyright: ignore[reportIncompatibleMethodOverride] + variable = self._get_iterator_variable() - if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") - - if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): - raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") - - if isinstance(variable, NoneSegment) or len(variable.value) == 0: - # Try our best to preserve the type informat. - if isinstance(variable, ArraySegment): - output = variable.model_copy(update={"value": []}) - else: - output = ArrayAnySegment(value=[]) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - # TODO(QuantumGhost): is it possible to compute the type of `output` - # from graph definition? - outputs={"output": output}, - ) - ) + if self._is_empty_iteration(variable): + yield from self._handle_empty_iteration(variable) return - iterator_list_value = variable.to_object() - - if not isinstance(iterator_list_value, list): - raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - + iterator_list_value = self._validate_and_get_iterator_list(variable) inputs = {"iterator_selector": iterator_list_value} - if not self._node_data.start_node_id: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") + self._validate_start_node() started_at = naive_utc_now() iter_run_map: dict[str, float] = {} @@ -137,62 +118,150 @@ class IterationNode(Node): ) try: - for index, item in enumerate(iterator_list_value): - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - yield IterationNextEvent(index=index) - - graph_engine = self._create_graph_engine(index, item) - - # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - - # Update the total tokens from this iteration - self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - yield IterationSucceededEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, + yield from self._execute_iterations( + iterator_list_value=iterator_list_value, + outputs=outputs, + iter_run_map=iter_run_map, ) - # Yield final success event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": outputs}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, - }, - ) + yield from self._handle_iteration_success( + started_at=started_at, + inputs=inputs, + outputs=outputs, + iterator_list_value=iterator_list_value, + iter_run_map=iter_run_map, ) except IterationNodeError as e: - yield IterationFailedEvent( - start_at=started_at, + yield from self._handle_iteration_failure( + started_at=started_at, inputs=inputs, + outputs=outputs, + iterator_list_value=iterator_list_value, + iter_run_map=iter_run_map, + error=e, + ) + + def _get_iterator_variable(self) -> ArraySegment | NoneSegment: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) + + if not variable: + raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") + + if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): + raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") + + return variable + + def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]: + return isinstance(variable, NoneSegment) or len(variable.value) == 0 + + def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]: + # Try our best to preserve the type information. + if isinstance(variable, ArraySegment): + output = variable.model_copy(update={"value": []}) + else: + output = ArrayAnySegment(value=[]) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + # TODO(QuantumGhost): is it possible to compute the type of `output` + # from graph definition? + outputs={"output": output}, + ) + ) + + def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]: + iterator_list_value = variable.to_object() + + if not isinstance(iterator_list_value, list): + raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") + + return cast(list[object], iterator_list_value) + + def _validate_start_node(self) -> None: + if not self._node_data.start_node_id: + raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") + + def _execute_iterations( + self, + iterator_list_value: Sequence[object], + outputs: list[Any], + iter_run_map: dict[str, float], + ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: + for index, item in enumerate(iterator_list_value): + iter_start_at = datetime.now(UTC).replace(tzinfo=None) + yield IterationNextEvent(index=index) + + graph_engine = self._create_graph_engine(index, item) + + # Run the iteration + yield from self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs, + graph_engine=graph_engine, + ) + + # Update the total tokens from this iteration + self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens + iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + + def _handle_iteration_success( + self, + started_at: datetime, + inputs: dict[str, Sequence[object]], + outputs: list[Any], + iterator_list_value: Sequence[object], + iter_run_map: dict[str, float], + ) -> Generator[NodeEventBase, None, None]: + yield IterationSucceededEvent( + start_at=started_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, + }, + ) + + # Yield final success event + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": outputs}, - steps=len(iterator_list_value), metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, }, - error=str(e), ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) + ) + + def _handle_iteration_failure( + self, + started_at: datetime, + inputs: dict[str, Sequence[object]], + outputs: list[Any], + iterator_list_value: Sequence[object], + iter_run_map: dict[str, float], + error: IterationNodeError, + ) -> Generator[NodeEventBase, None, None]: + yield IterationFailedEvent( + start_at=started_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, + }, + error=str(error), + ) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(error), ) + ) @classmethod def _extract_variable_selector_to_variable_mapping( @@ -305,9 +374,9 @@ class IterationNode(Node): self, *, variable_pool: VariablePool, - outputs: list, + outputs: list[object], graph_engine: "GraphEngine", - ) -> Generator[Union[GraphNodeEventBase, StreamCompletedEvent], None, None]: + ) -> Generator[GraphNodeEventBase, None, None]: rst = graph_engine.run() # get current iteration index index_variable = variable_pool.get([self._node_id, "index"])