mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 22:28:55 +08:00
fix(api): avoid recursive loop type adapters
This commit is contained in:
parent
c5920fb28a
commit
9a86f280eb
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
from typing import Annotated, Any, Literal, TypeAlias, cast
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
@ -12,12 +12,10 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, "LoopValue"] | list["LoopValue"]
|
||||
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
LoopValueMapping: TypeAlias = dict[str, LoopValue]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_LOOP_VALUE_ADAPTER: TypeAdapter[LoopValue] = TypeAdapter(LoopValue)
|
||||
_LOOP_VALUE_MAPPING_ADAPTER: TypeAdapter[LoopValueMapping] = TypeAdapter(LoopValueMapping)
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
_VALID_VAR_TYPE = frozenset(
|
||||
@ -40,6 +38,36 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
||||
return seg_type
|
||||
|
||||
|
||||
def _validate_loop_value(value: object) -> LoopValue:
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return cast(LoopValue, value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_validate_loop_value(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
normalized: dict[str, LoopValue] = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("Loop values only support string object keys")
|
||||
normalized[key] = _validate_loop_value(item)
|
||||
return normalized
|
||||
|
||||
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
|
||||
|
||||
|
||||
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError("Loop outputs must be an object")
|
||||
|
||||
normalized: LoopValueMapping = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("Loop output keys must be strings")
|
||||
normalized[key] = _validate_loop_value(item)
|
||||
return normalized
|
||||
|
||||
|
||||
class LoopVariableData(BaseModel):
|
||||
"""
|
||||
Loop Variable Data.
|
||||
@ -59,7 +87,7 @@ class LoopVariableData(BaseModel):
|
||||
return None
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if value_type == "constant":
|
||||
return _LOOP_VALUE_ADAPTER.validate_python(value)
|
||||
return _validate_loop_value(value)
|
||||
raise ValueError(f"Unknown loop variable value type: {value_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
@ -70,7 +98,7 @@ class LoopVariableData(BaseModel):
|
||||
def require_constant_value(self) -> LoopValue:
|
||||
if self.value_type != "constant":
|
||||
raise ValueError(f"Expected constant loop input, got {self.value_type}")
|
||||
return _LOOP_VALUE_ADAPTER.validate_python(self.value)
|
||||
return _validate_loop_value(self.value)
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
@ -86,7 +114,7 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
def validate_outputs(cls, value: object) -> LoopValueMapping:
|
||||
if value is None:
|
||||
return {}
|
||||
return _LOOP_VALUE_MAPPING_ADAPTER.validate_python(value)
|
||||
return _validate_loop_value_mapping(value)
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user