diff --git a/api/dify_graph/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py index 2fddb07ed8..2d24a30d3d 100644 --- a/api/dify_graph/nodes/loop/entities.py +++ b/api/dify_graph/nodes/loop/entities.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import StrEnum -from typing import Annotated, Literal, TypeAlias +from typing import Annotated, Any, Literal, TypeAlias, cast from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator from pydantic_core.core_schema import ValidationInfo @@ -12,12 +12,10 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState from dify_graph.utils.condition.entities import Condition from dify_graph.variables.types import SegmentType -LoopValue: TypeAlias = str | int | float | bool | None | dict[str, "LoopValue"] | list["LoopValue"] +LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] LoopValueMapping: TypeAlias = dict[str, LoopValue] VariableSelector: TypeAlias = list[str] -_LOOP_VALUE_ADAPTER: TypeAdapter[LoopValue] = TypeAdapter(LoopValue) -_LOOP_VALUE_MAPPING_ADAPTER: TypeAdapter[LoopValueMapping] = TypeAdapter(LoopValueMapping) _VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector) _VALID_VAR_TYPE = frozenset( @@ -40,6 +38,36 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: return seg_type +def _validate_loop_value(value: object) -> LoopValue: + if value is None or isinstance(value, (str, int, float, bool)): + return cast(LoopValue, value) + + if isinstance(value, list): + return [_validate_loop_value(item) for item in value] + + if isinstance(value, dict): + normalized: dict[str, LoopValue] = {} + for key, item in value.items(): + if not isinstance(key, str): + raise TypeError("Loop values only support string object keys") + normalized[key] = _validate_loop_value(item) + return normalized + + raise TypeError("Loop values must be JSON-like primitives, arrays, or objects") + + +def _validate_loop_value_mapping(value: object) -> LoopValueMapping: + if not isinstance(value, dict): + raise TypeError("Loop outputs must be an object") + + normalized: LoopValueMapping = {} + for key, item in value.items(): + if not isinstance(key, str): + raise TypeError("Loop output keys must be strings") + normalized[key] = _validate_loop_value(item) + return normalized + + class LoopVariableData(BaseModel): """ Loop Variable Data. @@ -59,7 +87,7 @@ class LoopVariableData(BaseModel): return None return _VARIABLE_SELECTOR_ADAPTER.validate_python(value) if value_type == "constant": - return _LOOP_VALUE_ADAPTER.validate_python(value) + return _validate_loop_value(value) raise ValueError(f"Unknown loop variable value type: {value_type}") def require_variable_selector(self) -> VariableSelector: @@ -70,7 +98,7 @@ class LoopVariableData(BaseModel): def require_constant_value(self) -> LoopValue: if self.value_type != "constant": raise ValueError(f"Expected constant loop input, got {self.value_type}") - return _LOOP_VALUE_ADAPTER.validate_python(self.value) + return _validate_loop_value(self.value) class LoopNodeData(BaseLoopNodeData): @@ -86,7 +114,7 @@ class LoopNodeData(BaseLoopNodeData): def validate_outputs(cls, value: object) -> LoopValueMapping: if value is None: return {} - return _LOOP_VALUE_MAPPING_ADAPTER.validate_python(value) + return _validate_loop_value_mapping(value) class LoopStartNodeData(BaseNodeData):