From 0d805e624e402c26e3a3fc27a8b45a50b3d70202 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yanli=20=E7=9B=90=E7=B2=92?= Date: Tue, 17 Mar 2026 19:39:54 +0800 Subject: [PATCH] Type phase 3 loop values --- api/dify_graph/nodes/loop/entities.py | 57 ++++++++++++++++++++------ api/dify_graph/nodes/loop/loop_node.py | 57 +++++++++++++++++--------- 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/api/dify_graph/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py index f0bfad5a0f..2fddb07ed8 100644 --- a/api/dify_graph/nodes/loop/entities.py +++ b/api/dify_graph/nodes/loop/entities.py @@ -1,7 +1,10 @@ -from enum import StrEnum -from typing import Annotated, Any, Literal +from __future__ import annotations -from pydantic import AfterValidator, BaseModel, Field, field_validator +from enum import StrEnum +from typing import Annotated, Literal, TypeAlias + +from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator +from pydantic_core.core_schema import ValidationInfo from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType @@ -9,6 +12,14 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState from dify_graph.utils.condition.entities import Condition from dify_graph.variables.types import SegmentType +LoopValue: TypeAlias = str | int | float | bool | None | dict[str, "LoopValue"] | list["LoopValue"] +LoopValueMapping: TypeAlias = dict[str, LoopValue] +VariableSelector: TypeAlias = list[str] + +_LOOP_VALUE_ADAPTER: TypeAdapter[LoopValue] = TypeAdapter(LoopValue) +_LOOP_VALUE_MAPPING_ADAPTER: TypeAdapter[LoopValueMapping] = TypeAdapter(LoopValueMapping) +_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector) + _VALID_VAR_TYPE = frozenset( [ SegmentType.STRING, @@ -37,7 +48,29 @@ class LoopVariableData(BaseModel): label: str var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] value_type: Literal["variable", "constant"] - value: Any | list[str] | None = None + value: LoopValue | VariableSelector | None = None + + @field_validator("value", mode="before") + @classmethod + def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None: + value_type = validation_info.data.get("value_type") + if value_type == "variable": + if value is None: + return None + return _VARIABLE_SELECTOR_ADAPTER.validate_python(value) + if value_type == "constant": + return _LOOP_VALUE_ADAPTER.validate_python(value) + raise ValueError(f"Unknown loop variable value type: {value_type}") + + def require_variable_selector(self) -> VariableSelector: + if self.value_type != "variable": + raise ValueError(f"Expected variable loop input, got {self.value_type}") + return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value) + + def require_constant_value(self) -> LoopValue: + if self.value_type != "constant": + raise ValueError(f"Expected constant loop input, got {self.value_type}") + return _LOOP_VALUE_ADAPTER.validate_python(self.value) class LoopNodeData(BaseLoopNodeData): @@ -46,14 +79,14 @@ class LoopNodeData(BaseLoopNodeData): break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) - outputs: dict[str, Any] = Field(default_factory=dict) + outputs: LoopValueMapping = Field(default_factory=dict) @field_validator("outputs", mode="before") @classmethod - def validate_outputs(cls, v): - if v is None: + def validate_outputs(cls, value: object) -> LoopValueMapping: + if value is None: return {} - return v + return _LOOP_VALUE_MAPPING_ADAPTER.validate_python(value) class LoopStartNodeData(BaseNodeData): @@ -77,8 +110,8 @@ class LoopState(BaseLoopState): Loop State. """ - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None + outputs: list[LoopValue] = Field(default_factory=list) + current_output: LoopValue | None = None class MetaData(BaseLoopState.MetaData): """ @@ -87,7 +120,7 @@ class LoopState(BaseLoopState): loop_length: int - def get_last_output(self) -> Any: + def get_last_output(self) -> LoopValue | None: """ Get last output. """ @@ -95,7 +128,7 @@ class LoopState(BaseLoopState): return self.outputs[-1] return None - def get_current_output(self) -> Any: + def get_current_output(self) -> LoopValue | None: """ Get current output. """ diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 3c546ffa23..a2d827f034 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -3,7 +3,7 @@ import json import logging from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Literal, cast from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( @@ -29,7 +29,7 @@ from dify_graph.node_events import ( ) from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData +from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData from dify_graph.utils.condition.processor import ConditionProcessor from dify_graph.variables import Segment, SegmentType from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable @@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): break_conditions = self.node_data.break_conditions logical_operator = self.node_data.logical_operator - inputs = {"loop_count": loop_count} + inputs: dict[str, object] = {"loop_count": loop_count} if not self.node_data.start_node_id: raise ValueError(f"field start_node_id in loop {self._node_id} not found") @@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): root_node_id = self.node_data.start_node_id # Initialize loop variables in the original variable pool - loop_variable_selectors = {} + loop_variable_selectors: dict[str, list[str]] = {} if self.node_data.loop_variables: value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { - "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), + "constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()), "variable": lambda var: ( - self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None + self.graph_runtime_state.variable_pool.get(var.require_variable_selector()) + if var.value is not None + else None ), } for loop_variable in self.node_data.loop_variables: @@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): condition_processor = ConditionProcessor() loop_duration_map: dict[str, float] = {} - single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output + single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output loop_usage = LLMUsage.empty_usage() loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id) @@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) # Collect loop variable values after iteration - single_loop_variable = {} + single_loop_variable: dict[str, LoopValue] = {} for key, selector in loop_variable_selectors.items(): segment = self.graph_runtime_state.variable_pool.get(selector) single_loop_variable[key] = segment.value if segment else None @@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): def _extract_variable_selector_to_variable_mapping( cls, *, - graph_config: Mapping[str, Any], + graph_config: Mapping[str, object], node_id: str, node_data: LoopNodeData, ) -> Mapping[str, Sequence[str]]: - variable_mapping = {} + variable_mapping: dict[str, Sequence[str]] = {} # Extract loop node IDs statically from graph_config loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) # Get node configs from graph_config - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} + raw_nodes = graph_config.get("nodes") + node_configs: dict[str, Mapping[str, object]] = {} + if isinstance(raw_nodes, list): + for raw_node in raw_nodes: + if not isinstance(raw_node, dict): + continue + raw_node_id = raw_node.get("id") + if isinstance(raw_node_id, str): + node_configs[raw_node_id] = raw_node for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("loop_id") != node_id: + sub_node_data = sub_node_config.get("data") + if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id: continue # variable selector to variable mapping @@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): for loop_variable in node_data.loop_variables or []: if loop_variable.value_type == "variable": - assert loop_variable.value is not None, "Loop variable value must be provided for variable type" # add loop variable to variable mapping - selector = loop_variable.value + selector = loop_variable.require_variable_selector() variable_mapping[f"{node_id}.{loop_variable.label}"] = selector # remove variable out from loop @@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): return variable_mapping @classmethod - def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: + def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]: """ Extract node IDs that belong to a specific loop from graph configuration. @@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): :param loop_node_id: the ID of the loop node :return: set of node IDs that belong to the loop """ - loop_node_ids = set() + loop_node_ids: set[str] = set() # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_data = node.get("data", {}) + raw_nodes = graph_config.get("nodes") + if not isinstance(raw_nodes, list): + return loop_node_ids + + for node in raw_nodes: + if not isinstance(node, dict): + continue + node_data = node.get("data") + if not isinstance(node_data, dict): + continue if node_data.get("loop_id") == loop_node_id: node_id = node.get("id") if node_id: @@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): return loop_node_ids @staticmethod - def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: + def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment: """Get the appropriate segment type for a constant value.""" # TODO: Refactor for maintainability: # 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)