Type phase 3 loop values

This commit is contained in:
Yanli 盐粒 2026-03-17 19:39:54 +08:00
parent 61196180b8
commit 0d805e624e
2 changed files with 82 additions and 32 deletions

View File

@ -1,7 +1,10 @@
from enum import StrEnum
from typing import Annotated, Any, Literal
from __future__ import annotations
from pydantic import AfterValidator, BaseModel, Field, field_validator
from enum import StrEnum
from typing import Annotated, Literal, TypeAlias
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
@ -9,6 +12,14 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, "LoopValue"] | list["LoopValue"]
LoopValueMapping: TypeAlias = dict[str, LoopValue]
VariableSelector: TypeAlias = list[str]
_LOOP_VALUE_ADAPTER: TypeAdapter[LoopValue] = TypeAdapter(LoopValue)
_LOOP_VALUE_MAPPING_ADAPTER: TypeAdapter[LoopValueMapping] = TypeAdapter(LoopValueMapping)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
@ -37,7 +48,29 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Any | list[str] | None = None
value: LoopValue | VariableSelector | None = None
@field_validator("value", mode="before")
@classmethod
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
value_type = validation_info.data.get("value_type")
if value_type == "variable":
if value is None:
return None
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if value_type == "constant":
return _LOOP_VALUE_ADAPTER.validate_python(value)
raise ValueError(f"Unknown loop variable value type: {value_type}")
def require_variable_selector(self) -> VariableSelector:
if self.value_type != "variable":
raise ValueError(f"Expected variable loop input, got {self.value_type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def require_constant_value(self) -> LoopValue:
if self.value_type != "constant":
raise ValueError(f"Expected constant loop input, got {self.value_type}")
return _LOOP_VALUE_ADAPTER.validate_python(self.value)
class LoopNodeData(BaseLoopNodeData):
@ -46,14 +79,14 @@ class LoopNodeData(BaseLoopNodeData):
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: dict[str, Any] = Field(default_factory=dict)
outputs: LoopValueMapping = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@classmethod
def validate_outputs(cls, v):
if v is None:
def validate_outputs(cls, value: object) -> LoopValueMapping:
if value is None:
return {}
return v
return _LOOP_VALUE_MAPPING_ADAPTER.validate_python(value)
class LoopStartNodeData(BaseNodeData):
@ -77,8 +110,8 @@ class LoopState(BaseLoopState):
Loop State.
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Any = None
outputs: list[LoopValue] = Field(default_factory=list)
current_output: LoopValue | None = None
class MetaData(BaseLoopState.MetaData):
"""
@ -87,7 +120,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Any:
def get_last_output(self) -> LoopValue | None:
"""
Get last output.
"""
@ -95,7 +128,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Any:
def get_current_output(self) -> LoopValue | None:
"""
Get current output.
"""

View File

@ -3,7 +3,7 @@ import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
@ -29,7 +29,7 @@ from dify_graph.node_events import (
)
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
from dify_graph.utils.condition.processor import ConditionProcessor
from dify_graph.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count}
inputs: dict[str, object] = {"loop_count": loop_count}
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
loop_variable_selectors: dict[str, list[str]] = {}
if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
"variable": lambda var: (
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
if var.value is not None
else None
),
}
for loop_variable in self.node_data.loop_variables:
@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration
single_loop_variable = {}
single_loop_variable: dict[str, LoopValue] = {}
for key, selector in loop_variable_selectors.items():
segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None
@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
graph_config: Mapping[str, object],
node_id: str,
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
variable_mapping = {}
variable_mapping: dict[str, Sequence[str]] = {}
# Extract loop node IDs statically from graph_config
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
# Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
raw_nodes = graph_config.get("nodes")
node_configs: dict[str, Mapping[str, object]] = {}
if isinstance(raw_nodes, list):
for raw_node in raw_nodes:
if not isinstance(raw_node, dict):
continue
raw_node_id = raw_node.get("id")
if isinstance(raw_node_id, str):
node_configs[raw_node_id] = raw_node
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
sub_node_data = sub_node_config.get("data")
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
continue
# variable selector to variable mapping
@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
selector = loop_variable.value
selector = loop_variable.require_variable_selector()
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
loop_node_ids: set[str] = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
raw_nodes = graph_config.get("nodes")
if not isinstance(raw_nodes, list):
return loop_node_ids
for node in raw_nodes:
if not isinstance(node, dict):
continue
node_data = node.get("data")
if not isinstance(node_data, dict):
continue
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return loop_node_ids
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
"""Get the appropriate segment type for a constant value."""
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)