mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 22:28:55 +08:00
Type phase 3 loop values
This commit is contained in:
parent
61196180b8
commit
0d805e624e
@ -1,7 +1,10 @@
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
@ -9,6 +12,14 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, "LoopValue"] | list["LoopValue"]
|
||||
LoopValueMapping: TypeAlias = dict[str, LoopValue]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_LOOP_VALUE_ADAPTER: TypeAdapter[LoopValue] = TypeAdapter(LoopValue)
|
||||
_LOOP_VALUE_MAPPING_ADAPTER: TypeAdapter[LoopValueMapping] = TypeAdapter(LoopValueMapping)
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
_VALID_VAR_TYPE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
@ -37,7 +48,29 @@ class LoopVariableData(BaseModel):
|
||||
label: str
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Any | list[str] | None = None
|
||||
value: LoopValue | VariableSelector | None = None
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
|
||||
value_type = validation_info.data.get("value_type")
|
||||
if value_type == "variable":
|
||||
if value is None:
|
||||
return None
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if value_type == "constant":
|
||||
return _LOOP_VALUE_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown loop variable value type: {value_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.value_type != "variable":
|
||||
raise ValueError(f"Expected variable loop input, got {self.value_type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
def require_constant_value(self) -> LoopValue:
|
||||
if self.value_type != "constant":
|
||||
raise ValueError(f"Expected constant loop input, got {self.value_type}")
|
||||
return _LOOP_VALUE_ADAPTER.validate_python(self.value)
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
@ -46,14 +79,14 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
outputs: LoopValueMapping = Field(default_factory=dict)
|
||||
|
||||
@field_validator("outputs", mode="before")
|
||||
@classmethod
|
||||
def validate_outputs(cls, v):
|
||||
if v is None:
|
||||
def validate_outputs(cls, value: object) -> LoopValueMapping:
|
||||
if value is None:
|
||||
return {}
|
||||
return v
|
||||
return _LOOP_VALUE_MAPPING_ADAPTER.validate_python(value)
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
@ -77,8 +110,8 @@ class LoopState(BaseLoopState):
|
||||
Loop State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Any = None
|
||||
outputs: list[LoopValue] = Field(default_factory=list)
|
||||
current_output: LoopValue | None = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
@ -87,7 +120,7 @@ class LoopState(BaseLoopState):
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Any:
|
||||
def get_last_output(self) -> LoopValue | None:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@ -95,7 +128,7 @@ class LoopState(BaseLoopState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Any:
|
||||
def get_current_output(self) -> LoopValue | None:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
@ -29,7 +29,7 @@ from dify_graph.node_events import (
|
||||
)
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
|
||||
from dify_graph.utils.condition.processor import ConditionProcessor
|
||||
from dify_graph.variables import Segment, SegmentType
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
break_conditions = self.node_data.break_conditions
|
||||
logical_operator = self.node_data.logical_operator
|
||||
|
||||
inputs = {"loop_count": loop_count}
|
||||
inputs: dict[str, object] = {"loop_count": loop_count}
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||
@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# Initialize loop variables in the original variable pool
|
||||
loop_variable_selectors = {}
|
||||
loop_variable_selectors: dict[str, list[str]] = {}
|
||||
if self.node_data.loop_variables:
|
||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
|
||||
"variable": lambda var: (
|
||||
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
|
||||
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
|
||||
if var.value is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
|
||||
loop_usage = LLMUsage.empty_usage()
|
||||
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
|
||||
|
||||
@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable = {}
|
||||
single_loop_variable: dict[str, LoopValue] = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
segment = self.graph_runtime_state.variable_pool.get(selector)
|
||||
single_loop_variable[key] = segment.value if segment else None
|
||||
@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: Mapping[str, object],
|
||||
node_id: str,
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
variable_mapping = {}
|
||||
variable_mapping: dict[str, Sequence[str]] = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
|
||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||
|
||||
# Get node configs from graph_config
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
raw_nodes = graph_config.get("nodes")
|
||||
node_configs: dict[str, Mapping[str, object]] = {}
|
||||
if isinstance(raw_nodes, list):
|
||||
for raw_node in raw_nodes:
|
||||
if not isinstance(raw_node, dict):
|
||||
continue
|
||||
raw_node_id = raw_node.get("id")
|
||||
if isinstance(raw_node_id, str):
|
||||
node_configs[raw_node_id] = raw_node
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
sub_node_data = sub_node_config.get("data")
|
||||
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
selector = loop_variable.value
|
||||
selector = loop_variable.require_variable_selector()
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
|
||||
"""
|
||||
Extract node IDs that belong to a specific loop from graph configuration.
|
||||
|
||||
@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
:param loop_node_id: the ID of the loop node
|
||||
:return: set of node IDs that belong to the loop
|
||||
"""
|
||||
loop_node_ids = set()
|
||||
loop_node_ids: set[str] = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
raw_nodes = graph_config.get("nodes")
|
||||
if not isinstance(raw_nodes, list):
|
||||
return loop_node_ids
|
||||
|
||||
for node in raw_nodes:
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
node_data = node.get("data")
|
||||
if not isinstance(node_data, dict):
|
||||
continue
|
||||
if node_data.get("loop_id") == loop_node_id:
|
||||
node_id = node.get("id")
|
||||
if node_id:
|
||||
@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return loop_node_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
# TODO: Refactor for maintainability:
|
||||
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user