fix(api): avoid recursive loop type adapters

This commit is contained in:
Yanli 盐粒 2026-03-18 18:20:43 +08:00
parent c5920fb28a
commit 9a86f280eb

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from enum import StrEnum
from typing import Annotated, Literal, TypeAlias
from typing import Annotated, Any, Literal, TypeAlias, cast
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
@ -12,12 +12,10 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, "LoopValue"] | list["LoopValue"]
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
LoopValueMapping: TypeAlias = dict[str, LoopValue]
VariableSelector: TypeAlias = list[str]
_LOOP_VALUE_ADAPTER: TypeAdapter[LoopValue] = TypeAdapter(LoopValue)
_LOOP_VALUE_MAPPING_ADAPTER: TypeAdapter[LoopValueMapping] = TypeAdapter(LoopValueMapping)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
_VALID_VAR_TYPE = frozenset(
@ -40,6 +38,36 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
return seg_type
def _validate_loop_value(value: object) -> LoopValue:
if value is None or isinstance(value, (str, int, float, bool)):
return cast(LoopValue, value)
if isinstance(value, list):
return [_validate_loop_value(item) for item in value]
if isinstance(value, dict):
normalized: dict[str, LoopValue] = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop values only support string object keys")
normalized[key] = _validate_loop_value(item)
return normalized
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
if not isinstance(value, dict):
raise TypeError("Loop outputs must be an object")
normalized: LoopValueMapping = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop output keys must be strings")
normalized[key] = _validate_loop_value(item)
return normalized
class LoopVariableData(BaseModel):
"""
Loop Variable Data.
@ -59,7 +87,7 @@ class LoopVariableData(BaseModel):
return None
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if value_type == "constant":
return _LOOP_VALUE_ADAPTER.validate_python(value)
return _validate_loop_value(value)
raise ValueError(f"Unknown loop variable value type: {value_type}")
def require_variable_selector(self) -> VariableSelector:
@ -70,7 +98,7 @@ class LoopVariableData(BaseModel):
def require_constant_value(self) -> LoopValue:
if self.value_type != "constant":
raise ValueError(f"Expected constant loop input, got {self.value_type}")
return _LOOP_VALUE_ADAPTER.validate_python(self.value)
return _validate_loop_value(self.value)
class LoopNodeData(BaseLoopNodeData):
@ -86,7 +114,7 @@ class LoopNodeData(BaseLoopNodeData):
def validate_outputs(cls, value: object) -> LoopValueMapping:
if value is None:
return {}
return _LOOP_VALUE_MAPPING_ADAPTER.validate_python(value)
return _validate_loop_value_mapping(value)
class LoopStartNodeData(BaseNodeData):