mirror of https://github.com/langgenius/dify.git
refactor: extract _run method into smaller focused methods in IterationNode
- Extract iterator variable retrieval and validation logic
- Separate empty iteration handling
- Create dedicated methods for iteration execution and result handling
- Improve type hints and use modern Python syntax
- Enhance code readability and maintainability
parent 1c9f40f92a
commit e0e82fbfaa
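The "modern Python syntax" mentioned in the commit message includes a `NewType` marker plus a `typing_extensions.TypeIs` guard for the empty-iteration check. Below is a minimal, self-contained sketch of that narrowing pattern; the segment classes are simplified stand-ins, not the real `core.variables` types, and Python 3.10+ with `typing_extensions` installed is assumed.

```python
# Hypothetical stand-ins for illustration only; not the dify core.variables classes.
from typing import NewType

from typing_extensions import TypeIs  # also available as typing.TypeIs on Python 3.13+


class ArraySegment:
    def __init__(self, value: list[object]) -> None:
        self.value = value


class NoneSegment:
    value: None = None


# Marker type: an ArraySegment that is known to hold no items.
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)


def is_empty_iteration(variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
    # A True result tells the type checker that the caller may treat `variable`
    # as NoneSegment | EmptyArraySegment from this point on.
    return isinstance(variable, NoneSegment) or len(variable.value) == 0


print(is_empty_iteration(ArraySegment([])))      # True: empty array
print(is_empty_iteration(ArraySegment([1, 2])))  # False: has items
print(is_empty_iteration(NoneSegment()))         # True: no value at all
```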
@@ -1,7 +1,9 @@
 import logging
 from collections.abc import Generator, Mapping, Sequence
 from datetime import UTC, datetime
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, NewType, cast

+from typing_extensions import TypeIs
+
 from core.variables import IntegerVariable, NoneSegment
 from core.variables.segments import ArrayAnySegment, ArraySegment
@@ -23,6 +25,7 @@ from core.workflow.node_events import (
     IterationNextEvent,
     IterationStartedEvent,
     IterationSucceededEvent,
+    NodeEventBase,
     NodeRunResult,
     StreamCompletedEvent,
 )
@@ -45,6 +48,7 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)

+EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)

 class IterationNode(Node):
     """
@@ -58,7 +62,7 @@ class IterationNode(Node):
     def init_node_data(self, data: Mapping[str, Any]):
         self._node_data = IterationNodeData.model_validate(data)

-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy

     def _get_retry_config(self) -> RetryConfig:
@@ -67,7 +71,7 @@ class IterationNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title

-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc

     def _get_default_value_dict(self) -> dict[str, Any]:
@@ -77,7 +81,7 @@ class IterationNode(Node):
         return self._node_data

     @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None):
+    def get_default_config(cls, filters: dict[str, object] | None = None):
         return {
             "type": "iteration",
             "config": {
@@ -91,40 +95,17 @@ class IterationNode(Node):
     def version(cls) -> str:
         return "1"

-    def _run(self) -> Generator:
-        variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
+    def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:  # pyright: ignore[reportIncompatibleMethodOverride]
+        variable = self._get_iterator_variable()

-        if not variable:
-            raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
-
-        if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
-            raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
-
-        if isinstance(variable, NoneSegment) or len(variable.value) == 0:
-            # Try our best to preserve the type informat.
-            if isinstance(variable, ArraySegment):
-                output = variable.model_copy(update={"value": []})
-            else:
-                output = ArrayAnySegment(value=[])
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    # TODO(QuantumGhost): is it possible to compute the type of `output`
-                    # from graph definition?
-                    outputs={"output": output},
-                )
-            )
+        if self._is_empty_iteration(variable):
+            yield from self._handle_empty_iteration(variable)
             return

-        iterator_list_value = variable.to_object()
-
-        if not isinstance(iterator_list_value, list):
-            raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
-
+        iterator_list_value = self._validate_and_get_iterator_list(variable)
         inputs = {"iterator_selector": iterator_list_value}

-        if not self._node_data.start_node_id:
-            raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
+        self._validate_start_node()

         started_at = naive_utc_now()
         iter_run_map: dict[str, float] = {}
@@ -137,62 +118,150 @@ class IterationNode(Node):
         )

         try:
-            for index, item in enumerate(iterator_list_value):
-                iter_start_at = datetime.now(UTC).replace(tzinfo=None)
-                yield IterationNextEvent(index=index)
-
-                graph_engine = self._create_graph_engine(index, item)
-
-                # Run the iteration
-                yield from self._run_single_iter(
-                    variable_pool=graph_engine.graph_runtime_state.variable_pool,
-                    outputs=outputs,
-                    graph_engine=graph_engine,
-                )
-
-                # Update the total tokens from this iteration
-                self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
-                iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
-
-            yield IterationSucceededEvent(
-                start_at=started_at,
-                inputs=inputs,
-                outputs={"output": outputs},
-                steps=len(iterator_list_value),
-                metadata={
-                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
-                    WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
-                },
-            )
-
-            # Yield final success event
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    outputs={"output": outputs},
-                    metadata={
-                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
-                    },
-                )
-            )
+            yield from self._execute_iterations(
+                iterator_list_value=iterator_list_value,
+                outputs=outputs,
+                iter_run_map=iter_run_map,
+            )
+            yield from self._handle_iteration_success(
+                started_at=started_at,
+                inputs=inputs,
+                outputs=outputs,
+                iterator_list_value=iterator_list_value,
+                iter_run_map=iter_run_map,
+            )
         except IterationNodeError as e:
-            yield IterationFailedEvent(
-                start_at=started_at,
-                inputs=inputs,
-                outputs={"output": outputs},
-                steps=len(iterator_list_value),
-                metadata={
-                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
-                    WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
-                },
-                error=str(e),
-            )
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.FAILED,
-                    error=str(e),
-                )
-            )
+            yield from self._handle_iteration_failure(
+                started_at=started_at,
+                inputs=inputs,
+                outputs=outputs,
+                iterator_list_value=iterator_list_value,
+                iter_run_map=iter_run_map,
+                error=e,
+            )
+
+    def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
+        variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
+
+        if not variable:
+            raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
+
+        if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
+            raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
+
+        return variable
+
+    def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]:
+        return isinstance(variable, NoneSegment) or len(variable.value) == 0
+
+    def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]:
+        # Try our best to preserve the type information.
+        if isinstance(variable, ArraySegment):
+            output = variable.model_copy(update={"value": []})
+        else:
+            output = ArrayAnySegment(value=[])
+
+        yield StreamCompletedEvent(
+            node_run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                # TODO(QuantumGhost): is it possible to compute the type of `output`
+                # from graph definition?
+                outputs={"output": output},
+            )
+        )
+
+    def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]:
+        iterator_list_value = variable.to_object()
+
+        if not isinstance(iterator_list_value, list):
+            raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
+
+        return cast(list[object], iterator_list_value)
+
+    def _validate_start_node(self) -> None:
+        if not self._node_data.start_node_id:
+            raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
+
+    def _execute_iterations(
+        self,
+        iterator_list_value: Sequence[object],
+        outputs: list[Any],
+        iter_run_map: dict[str, float],
+    ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
+        for index, item in enumerate(iterator_list_value):
+            iter_start_at = datetime.now(UTC).replace(tzinfo=None)
+            yield IterationNextEvent(index=index)
+
+            graph_engine = self._create_graph_engine(index, item)
+
+            # Run the iteration
+            yield from self._run_single_iter(
+                variable_pool=graph_engine.graph_runtime_state.variable_pool,
+                outputs=outputs,
+                graph_engine=graph_engine,
+            )
+
+            # Update the total tokens from this iteration
+            self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
+            iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
+
+    def _handle_iteration_success(
+        self,
+        started_at: datetime,
+        inputs: dict[str, Sequence[object]],
+        outputs: list[Any],
+        iterator_list_value: Sequence[object],
+        iter_run_map: dict[str, float],
+    ) -> Generator[NodeEventBase, None, None]:
+        yield IterationSucceededEvent(
+            start_at=started_at,
+            inputs=inputs,
+            outputs={"output": outputs},
+            steps=len(iterator_list_value),
+            metadata={
+                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
+                WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
+            },
+        )

+        # Yield final success event
+        yield StreamCompletedEvent(
+            node_run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                outputs={"output": outputs},
+                metadata={
+                    WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
+                },
+            )
+        )
+
+    def _handle_iteration_failure(
+        self,
+        started_at: datetime,
+        inputs: dict[str, Sequence[object]],
+        outputs: list[Any],
+        iterator_list_value: Sequence[object],
+        iter_run_map: dict[str, float],
+        error: IterationNodeError,
+    ) -> Generator[NodeEventBase, None, None]:
+        yield IterationFailedEvent(
+            start_at=started_at,
+            inputs=inputs,
+            outputs={"output": outputs},
+            steps=len(iterator_list_value),
+            metadata={
+                WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
+                WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
+            },
+            error=str(error),
+        )
+        yield StreamCompletedEvent(
+            node_run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(error),
+            )
+        )

     @classmethod
     def _extract_variable_selector_to_variable_mapping(
@@ -305,9 +374,9 @@ class IterationNode(Node):
         self,
         *,
         variable_pool: VariablePool,
-        outputs: list,
+        outputs: list[object],
         graph_engine: "GraphEngine",
-    ) -> Generator[Union[GraphNodeEventBase, StreamCompletedEvent], None, None]:
+    ) -> Generator[GraphNodeEventBase, None, None]:
         rst = graph_engine.run()
         # get current iteration index
         index_variable = variable_pool.get([self._node_id, "index"])
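For context on the structural change itself: the refactor splits one long generator method into focused helper generators that `_run` re-emits with `yield from`. A minimal sketch of that delegation pattern follows, using hypothetical names and plain event strings in place of the real node events; it is not the dify API.

```python
# Illustrative only: strings stand in for the node events
# (IterationNextEvent, IterationSucceededEvent, ...) yielded by the real node.
from collections.abc import Generator, Sequence


def execute_iterations(items: Sequence[object], outputs: list[object]) -> Generator[str, None, None]:
    for index, item in enumerate(items):
        yield f"iteration_next index={index}"  # progress event per item
        outputs.append(item)                   # collect this iteration's result


def handle_success(outputs: list[object]) -> Generator[str, None, None]:
    yield f"iteration_succeeded steps={len(outputs)}"
    yield "stream_completed"


def run(items: Sequence[object]) -> Generator[str, None, None]:
    outputs: list[object] = []
    # Each phase lives in its own generator; `yield from` forwards its events
    # to the caller while keeping `run` itself short and readable.
    yield from execute_iterations(items, outputs)
    yield from handle_success(outputs)


for event in run(["a", "b", "c"]):
    print(event)
```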