refactor: extract _run method into smaller focused methods in IterationNode

- Extract iterator variable retrieval and validation logic
- Separate empty iteration handling
- Create dedicated methods for iteration execution and result handling
- Improve type hints and use modern Python syntax
- Enhance code readability and maintainability
This commit is contained in:
-LAN- 2025-09-10 01:15:36 +08:00
parent 1c9f40f92a
commit e0e82fbfaa
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
1 changed files with 149 additions and 80 deletions

View File

@ -1,7 +1,9 @@
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, NewType, cast
from typing_extensions import TypeIs
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
@ -23,6 +25,7 @@ from core.workflow.node_events import (
IterationNextEvent,
IterationStartedEvent,
IterationSucceededEvent,
NodeEventBase,
NodeRunResult,
StreamCompletedEvent,
)
@ -45,6 +48,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(Node):
"""
@ -58,7 +62,7 @@ class IterationNode(Node):
def init_node_data(self, data: Mapping[str, Any]) -> None:
    # Parse the raw mapping with IterationNodeData.model_validate and keep the
    # resulting model on the node as `_node_data`.
    self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
    """Return the `error_strategy` stored in the node data (may be ``None``)."""
    return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -67,7 +71,7 @@ class IterationNode(Node):
def _get_title(self) -> str:
    """Return the `title` stored in the node data."""
    return self._node_data.title
def _get_description(self) -> str | None:
    """Return the `desc` field stored in the node data (may be ``None``)."""
    return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -77,7 +81,7 @@ class IterationNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None):
def get_default_config(cls, filters: dict[str, object] | None = None):
return {
"type": "iteration",
"config": {
@ -91,40 +95,17 @@ class IterationNode(Node):
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # pyright: ignore[reportIncompatibleMethodOverride]
variable = self._get_iterator_variable()
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneSegment) or len(variable.value) == 0:
# Try our best to preserve the type information.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
# TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
)
)
if self._is_empty_iteration(variable):
yield from self._handle_empty_iteration(variable)
return
iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
iterator_list_value = self._validate_and_get_iterator_list(variable)
inputs = {"iterator_selector": iterator_list_value}
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
self._validate_start_node()
started_at = naive_utc_now()
iter_run_map: dict[str, float] = {}
@ -137,62 +118,150 @@ class IterationNode(Node):
)
try:
for index, item in enumerate(iterator_list_value):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
yield IterationNextEvent(index=index)
graph_engine = self._create_graph_engine(index, item)
# Run the iteration
yield from self._run_single_iter(
variable_pool=graph_engine.graph_runtime_state.variable_pool,
outputs=outputs,
graph_engine=graph_engine,
)
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
yield from self._execute_iterations(
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
)
# Yield final success event
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
},
)
yield from self._handle_iteration_success(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
)
except IterationNodeError as e:
yield IterationFailedEvent(
start_at=started_at,
yield from self._handle_iteration_failure(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
error=e,
)
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
    """
    Look up the iterator variable in the variable pool and check its kind.

    Returns:
        The segment stored under the node's `iterator_selector`.

    Raises:
        IteratorVariableNotFoundError: the pool lookup produced a falsy result.
        InvalidIteratorValueError: the segment is neither an ArraySegment
            nor a NoneSegment.
    """
    selector = self._node_data.iterator_selector
    variable = self.graph_runtime_state.variable_pool.get(selector)

    if not variable:
        raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")

    # Only array-like segments (or an explicit "none") can drive an iteration.
    if isinstance(variable, (ArraySegment, NoneSegment)):
        return variable

    raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
    """
    Tell whether there is nothing to iterate over.

    True for a NoneSegment, or for an array segment whose `value` has length 0.
    """
    if isinstance(variable, NoneSegment):
        return True
    return len(variable.value) == 0
def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
    """
    Emit the single completion event used when there are no items to iterate.

    The event carries a SUCCEEDED result whose `output` is an empty array
    segment: a copy of `variable` with an empty value when it is an
    ArraySegment, otherwise a fresh empty ArrayAnySegment.
    """
    # Keep the concrete array segment type when there is one to copy from.
    empty_output = (
        variable.model_copy(update={"value": []}) if isinstance(variable, ArraySegment) else ArrayAnySegment(value=[])
    )

    # TODO(QuantumGhost): is it possible to compute the type of `output`
    # from graph definition?
    result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        outputs={"output": empty_output},
    )
    yield StreamCompletedEvent(node_run_result=result)
def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
    """
    Convert the segment to a plain object and make sure it is a list.

    Returns:
        The list produced by `variable.to_object()`.

    Raises:
        InvalidIteratorValueError: `to_object()` returned something other than a list.
    """
    iterator_list_value = variable.to_object()
    if isinstance(iterator_list_value, list):
        # The cast only informs the type checker; the runtime value is returned untouched.
        return cast(list[object], iterator_list_value)
    raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
def _validate_start_node(self) -> None:
    """
    Ensure the node data names a start node for the iteration.

    Raises:
        StartNodeIdNotFoundError: `start_node_id` is missing or otherwise falsy.
    """
    if self._node_data.start_node_id:
        return
    raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
def _execute_iterations(
    self,
    iterator_list_value: Sequence[object],
    outputs: list[Any],
    iter_run_map: dict[str, float],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
    """
    Run one iteration per item, forwarding every event produced along the way.

    For each item this yields an IterationNextEvent, creates a graph engine
    for the item, streams the events of `_run_single_iter`, adds the engine's
    token count to this node's runtime state and records how long the
    iteration took.

    Args:
        iterator_list_value: items to iterate over, in order.
        outputs: list handed through to `_run_single_iter` unchanged.
        iter_run_map: mutated in place; maps `str(index)` to the elapsed
            seconds of that iteration. An entry is only written once the
            iteration has finished without raising.
    """
    # Elapsed time is measured with a monotonic clock so that system clock
    # adjustments cannot skew (or make negative) the recorded durations.
    from time import perf_counter

    for index, item in enumerate(iterator_list_value):
        iter_started = perf_counter()
        yield IterationNextEvent(index=index)

        graph_engine = self._create_graph_engine(index, item)

        # Run the iteration
        yield from self._run_single_iter(
            variable_pool=graph_engine.graph_runtime_state.variable_pool,
            outputs=outputs,
            graph_engine=graph_engine,
        )

        # Update the total tokens from this iteration
        self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
        iter_run_map[str(index)] = perf_counter() - iter_started
def _handle_iteration_success(
    self,
    started_at: datetime,
    inputs: dict[str, Sequence[object]],
    outputs: list[Any],
    iterator_list_value: Sequence[object],
    iter_run_map: dict[str, float],
) -> Generator[NodeEventBase, None, None]:
    """
    Emit the events that close a successful iteration run.

    Yields an IterationSucceededEvent (with the step count, total tokens and
    the per-iteration duration map) followed by the final
    StreamCompletedEvent carrying a SUCCEEDED NodeRunResult.

    Args:
        started_at: start time reported on the iteration event.
        inputs: inputs reported on the iteration event.
        outputs: collected results, published under the `output` key.
        iterator_list_value: the iterated items; only its length is used.
        iter_run_map: per-iteration durations keyed by stringified index.
    """
    yield IterationSucceededEvent(
        start_at=started_at,
        inputs=inputs,
        outputs={"output": outputs},
        steps=len(iterator_list_value),
        metadata={
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
            WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
        },
    )

    # Yield final success event
    yield StreamCompletedEvent(
        node_run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"output": outputs},
            metadata={
                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
            },
        )
    )
def _handle_iteration_failure(
    self,
    started_at: datetime,
    inputs: dict[str, Sequence[object]],
    outputs: list[Any],
    iterator_list_value: Sequence[object],
    iter_run_map: dict[str, float],
    error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
    """
    Emit the events that close a failed iteration run.

    Yields an IterationFailedEvent describing the state reached so far
    (outputs, step count, total tokens, duration map, error text), then a
    StreamCompletedEvent carrying a FAILED NodeRunResult with the same error
    text.

    Args:
        started_at: start time reported on the iteration event.
        inputs: inputs reported on the iteration event.
        outputs: results collected so far, published under the `output` key.
        iterator_list_value: the iterated items; only its length is used.
        iter_run_map: per-iteration durations keyed by stringified index.
        error: the iteration error being reported; rendered with `str()`.
    """
    failure_metadata = {
        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
        WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
    }
    failed_event = IterationFailedEvent(
        start_at=started_at,
        inputs=inputs,
        outputs={"output": outputs},
        steps=len(iterator_list_value),
        metadata=failure_metadata,
        error=str(error),
    )
    yield failed_event

    failed_result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.FAILED,
        error=str(error),
    )
    yield StreamCompletedEvent(node_run_result=failed_result)
@classmethod
def _extract_variable_selector_to_variable_mapping(
@ -305,9 +374,9 @@ class IterationNode(Node):
self,
*,
variable_pool: VariablePool,
outputs: list,
outputs: list[object],
graph_engine: "GraphEngine",
) -> Generator[Union[GraphNodeEventBase, StreamCompletedEvent], None, None]:
) -> Generator[GraphNodeEventBase, None, None]:
rst = graph_engine.run()
# get current iteration index
index_variable = variable_pool.get([self._node_id, "index"])