mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 17:47:30 +08:00
fix: fixed error when clear value of INTEGER and FLOAT type (#27954)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
7ca5e8812e
commit
96865ebc8c
@ -202,6 +202,35 @@ class SegmentType(StrEnum):
|
|||||||
raise ValueError(f"element_type is only supported by array type, got {self}")
|
raise ValueError(f"element_type is only supported by array type, got {self}")
|
||||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_zero_value(t: "SegmentType"):
|
||||||
|
# Lazy import to avoid circular dependency
|
||||||
|
from factories import variable_factory
|
||||||
|
|
||||||
|
match t:
|
||||||
|
case (
|
||||||
|
SegmentType.ARRAY_OBJECT
|
||||||
|
| SegmentType.ARRAY_ANY
|
||||||
|
| SegmentType.ARRAY_STRING
|
||||||
|
| SegmentType.ARRAY_NUMBER
|
||||||
|
| SegmentType.ARRAY_BOOLEAN
|
||||||
|
):
|
||||||
|
return variable_factory.build_segment_with_type(t, [])
|
||||||
|
case SegmentType.OBJECT:
|
||||||
|
return variable_factory.build_segment({})
|
||||||
|
case SegmentType.STRING:
|
||||||
|
return variable_factory.build_segment("")
|
||||||
|
case SegmentType.INTEGER:
|
||||||
|
return variable_factory.build_segment(0)
|
||||||
|
case SegmentType.FLOAT:
|
||||||
|
return variable_factory.build_segment(0.0)
|
||||||
|
case SegmentType.NUMBER:
|
||||||
|
return variable_factory.build_segment(0)
|
||||||
|
case SegmentType.BOOLEAN:
|
||||||
|
return variable_factory.build_segment(False)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"unsupported variable type: {t}")
|
||||||
|
|
||||||
|
|
||||||
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
|
||||||
# ARRAY_ANY does not have corresponding element type.
|
# ARRAY_ANY does not have corresponding element type.
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from collections.abc import Callable, Mapping, Sequence
|
|||||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||||
|
|
||||||
from core.variables import SegmentType, Variable
|
from core.variables import SegmentType, Variable
|
||||||
from core.variables.segments import BooleanSegment
|
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
@ -12,7 +11,6 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||||
from factories import variable_factory
|
|
||||||
|
|
||||||
from ..common.impl import conversation_variable_updater_factory
|
from ..common.impl import conversation_variable_updater_factory
|
||||||
from .node_data import VariableAssignerData, WriteMode
|
from .node_data import VariableAssignerData, WriteMode
|
||||||
@ -116,7 +114,7 @@ class VariableAssignerNode(Node):
|
|||||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||||
|
|
||||||
case WriteMode.CLEAR:
|
case WriteMode.CLEAR:
|
||||||
income_value = get_zero_value(original_variable.value_type)
|
income_value = SegmentType.get_zero_value(original_variable.value_type)
|
||||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||||
|
|
||||||
# Over write the variable.
|
# Over write the variable.
|
||||||
@ -143,24 +141,3 @@ class VariableAssignerNode(Node):
|
|||||||
process_data=common_helpers.set_updated_variables({}, updated_variables),
|
process_data=common_helpers.set_updated_variables({}, updated_variables),
|
||||||
outputs={},
|
outputs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_zero_value(t: SegmentType):
|
|
||||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
|
||||||
match t:
|
|
||||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
|
|
||||||
return variable_factory.build_segment_with_type(t, [])
|
|
||||||
case SegmentType.OBJECT:
|
|
||||||
return variable_factory.build_segment({})
|
|
||||||
case SegmentType.STRING:
|
|
||||||
return variable_factory.build_segment("")
|
|
||||||
case SegmentType.INTEGER:
|
|
||||||
return variable_factory.build_segment(0)
|
|
||||||
case SegmentType.FLOAT:
|
|
||||||
return variable_factory.build_segment(0.0)
|
|
||||||
case SegmentType.NUMBER:
|
|
||||||
return variable_factory.build_segment(0)
|
|
||||||
case SegmentType.BOOLEAN:
|
|
||||||
return BooleanSegment(value=False)
|
|
||||||
case _:
|
|
||||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
|
||||||
|
|||||||
@ -1,14 +0,0 @@
|
|||||||
from core.variables import SegmentType
|
|
||||||
|
|
||||||
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
|
|
||||||
EMPTY_VALUE_MAPPING = {
|
|
||||||
SegmentType.STRING: "",
|
|
||||||
SegmentType.NUMBER: 0,
|
|
||||||
SegmentType.BOOLEAN: False,
|
|
||||||
SegmentType.OBJECT: {},
|
|
||||||
SegmentType.ARRAY_ANY: [],
|
|
||||||
SegmentType.ARRAY_STRING: [],
|
|
||||||
SegmentType.ARRAY_NUMBER: [],
|
|
||||||
SegmentType.ARRAY_OBJECT: [],
|
|
||||||
SegmentType.ARRAY_BOOLEAN: [],
|
|
||||||
}
|
|
||||||
@ -16,7 +16,6 @@ from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNod
|
|||||||
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
|
||||||
|
|
||||||
from . import helpers
|
from . import helpers
|
||||||
from .constants import EMPTY_VALUE_MAPPING
|
|
||||||
from .entities import VariableAssignerNodeData, VariableOperationItem
|
from .entities import VariableAssignerNodeData, VariableOperationItem
|
||||||
from .enums import InputType, Operation
|
from .enums import InputType, Operation
|
||||||
from .exc import (
|
from .exc import (
|
||||||
@ -249,7 +248,7 @@ class VariableAssignerNode(Node):
|
|||||||
case Operation.OVER_WRITE:
|
case Operation.OVER_WRITE:
|
||||||
return value
|
return value
|
||||||
case Operation.CLEAR:
|
case Operation.CLEAR:
|
||||||
return EMPTY_VALUE_MAPPING[variable.value_type]
|
return SegmentType.get_zero_value(variable.value_type).to_object()
|
||||||
case Operation.APPEND:
|
case Operation.APPEND:
|
||||||
return variable.value + [value]
|
return variable.value + [value]
|
||||||
case Operation.EXTEND:
|
case Operation.EXTEND:
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
from core.variables.types import ArrayValidation, SegmentType
|
from core.variables.types import ArrayValidation, SegmentType
|
||||||
|
|
||||||
|
|
||||||
@ -83,3 +85,81 @@ class TestSegmentTypeIsValidArrayValidation:
|
|||||||
value = [1, 2, 3]
|
value = [1, 2, 3]
|
||||||
# validation is None, skip
|
# validation is None, skip
|
||||||
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
|
assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSegmentTypeGetZeroValue:
|
||||||
|
"""
|
||||||
|
Test class for SegmentType.get_zero_value static method.
|
||||||
|
|
||||||
|
Provides comprehensive coverage of all supported SegmentType values to ensure
|
||||||
|
correct zero value generation for each type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_array_types_return_empty_list(self):
|
||||||
|
"""Test that all array types return empty list segments."""
|
||||||
|
array_types = [
|
||||||
|
SegmentType.ARRAY_ANY,
|
||||||
|
SegmentType.ARRAY_STRING,
|
||||||
|
SegmentType.ARRAY_NUMBER,
|
||||||
|
SegmentType.ARRAY_OBJECT,
|
||||||
|
SegmentType.ARRAY_BOOLEAN,
|
||||||
|
]
|
||||||
|
|
||||||
|
for seg_type in array_types:
|
||||||
|
result = SegmentType.get_zero_value(seg_type)
|
||||||
|
assert result.value == []
|
||||||
|
assert result.value_type == seg_type
|
||||||
|
|
||||||
|
def test_object_returns_empty_dict(self):
|
||||||
|
"""Test that OBJECT type returns empty dictionary segment."""
|
||||||
|
result = SegmentType.get_zero_value(SegmentType.OBJECT)
|
||||||
|
assert result.value == {}
|
||||||
|
assert result.value_type == SegmentType.OBJECT
|
||||||
|
|
||||||
|
def test_string_returns_empty_string(self):
|
||||||
|
"""Test that STRING type returns empty string segment."""
|
||||||
|
result = SegmentType.get_zero_value(SegmentType.STRING)
|
||||||
|
assert result.value == ""
|
||||||
|
assert result.value_type == SegmentType.STRING
|
||||||
|
|
||||||
|
def test_integer_returns_zero(self):
|
||||||
|
"""Test that INTEGER type returns zero segment."""
|
||||||
|
result = SegmentType.get_zero_value(SegmentType.INTEGER)
|
||||||
|
assert result.value == 0
|
||||||
|
assert result.value_type == SegmentType.INTEGER
|
||||||
|
|
||||||
|
def test_float_returns_zero_point_zero(self):
|
||||||
|
"""Test that FLOAT type returns 0.0 segment."""
|
||||||
|
result = SegmentType.get_zero_value(SegmentType.FLOAT)
|
||||||
|
assert result.value == 0.0
|
||||||
|
assert result.value_type == SegmentType.FLOAT
|
||||||
|
|
||||||
|
def test_number_returns_zero(self):
|
||||||
|
"""Test that NUMBER type returns zero segment."""
|
||||||
|
result = SegmentType.get_zero_value(SegmentType.NUMBER)
|
||||||
|
assert result.value == 0
|
||||||
|
# NUMBER type with integer value returns INTEGER segment type
|
||||||
|
# (NUMBER is a union type that can be INTEGER or FLOAT)
|
||||||
|
assert result.value_type == SegmentType.INTEGER
|
||||||
|
# Verify that exposed_type returns NUMBER for frontend compatibility
|
||||||
|
assert result.value_type.exposed_type() == SegmentType.NUMBER
|
||||||
|
|
||||||
|
def test_boolean_returns_false(self):
|
||||||
|
"""Test that BOOLEAN type returns False segment."""
|
||||||
|
result = SegmentType.get_zero_value(SegmentType.BOOLEAN)
|
||||||
|
assert result.value is False
|
||||||
|
assert result.value_type == SegmentType.BOOLEAN
|
||||||
|
|
||||||
|
def test_unsupported_types_raise_value_error(self):
|
||||||
|
"""Test that unsupported types raise ValueError."""
|
||||||
|
unsupported_types = [
|
||||||
|
SegmentType.SECRET,
|
||||||
|
SegmentType.FILE,
|
||||||
|
SegmentType.NONE,
|
||||||
|
SegmentType.GROUP,
|
||||||
|
SegmentType.ARRAY_FILE,
|
||||||
|
]
|
||||||
|
|
||||||
|
for seg_type in unsupported_types:
|
||||||
|
with pytest.raises(ValueError, match="unsupported variable type"):
|
||||||
|
SegmentType.get_zero_value(seg_type)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user