add unit tests for iteration node (#28719)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Satoshi Dev 2025-11-26 18:36:47 -08:00 committed by GitHub
parent 766e16b26f
commit 5815950092
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 729 additions and 0 deletions

View File

@ -0,0 +1,339 @@
from core.workflow.nodes.iteration.entities import (
ErrorHandleMode,
IterationNodeData,
IterationStartNodeData,
IterationState,
)
class TestErrorHandleMode:
"""Test suite for ErrorHandleMode enum."""
def test_terminated_value(self):
"""Test TERMINATED enum value."""
assert ErrorHandleMode.TERMINATED == "terminated"
assert ErrorHandleMode.TERMINATED.value == "terminated"
def test_continue_on_error_value(self):
"""Test CONTINUE_ON_ERROR enum value."""
assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error"
def test_remove_abnormal_output_value(self):
"""Test REMOVE_ABNORMAL_OUTPUT enum value."""
assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output"
assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output"
def test_error_handle_mode_is_str_enum(self):
"""Test ErrorHandleMode is a string enum."""
assert isinstance(ErrorHandleMode.TERMINATED, str)
assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str)
assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str)
def test_error_handle_mode_comparison(self):
"""Test ErrorHandleMode can be compared with strings."""
assert ErrorHandleMode.TERMINATED == "terminated"
assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error"
def test_all_error_handle_modes(self):
"""Test all ErrorHandleMode values are accessible."""
modes = list(ErrorHandleMode)
assert len(modes) == 3
assert ErrorHandleMode.TERMINATED in modes
assert ErrorHandleMode.CONTINUE_ON_ERROR in modes
assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes
class TestIterationNodeData:
"""Test suite for IterationNodeData model."""
def test_iteration_node_data_basic(self):
"""Test IterationNodeData with basic configuration."""
data = IterationNodeData(
title="Test Iteration",
iterator_selector=["node1", "output"],
output_selector=["iteration", "result"],
)
assert data.title == "Test Iteration"
assert data.iterator_selector == ["node1", "output"]
assert data.output_selector == ["iteration", "result"]
def test_iteration_node_data_default_values(self):
"""Test IterationNodeData default values."""
data = IterationNodeData(
title="Default Test",
iterator_selector=["start", "items"],
output_selector=["iter", "out"],
)
assert data.parent_loop_id is None
assert data.is_parallel is False
assert data.parallel_nums == 10
assert data.error_handle_mode == ErrorHandleMode.TERMINATED
assert data.flatten_output is True
def test_iteration_node_data_parallel_mode(self):
"""Test IterationNodeData with parallel mode enabled."""
data = IterationNodeData(
title="Parallel Iteration",
iterator_selector=["node", "list"],
output_selector=["iter", "output"],
is_parallel=True,
parallel_nums=5,
)
assert data.is_parallel is True
assert data.parallel_nums == 5
def test_iteration_node_data_custom_parallel_nums(self):
"""Test IterationNodeData with custom parallel numbers."""
data = IterationNodeData(
title="Custom Parallel",
iterator_selector=["a", "b"],
output_selector=["c", "d"],
parallel_nums=20,
)
assert data.parallel_nums == 20
def test_iteration_node_data_continue_on_error(self):
"""Test IterationNodeData with continue on error mode."""
data = IterationNodeData(
title="Continue Error",
iterator_selector=["x", "y"],
output_selector=["z", "w"],
error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
)
assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
def test_iteration_node_data_remove_abnormal_output(self):
"""Test IterationNodeData with remove abnormal output mode."""
data = IterationNodeData(
title="Remove Abnormal",
iterator_selector=["input", "array"],
output_selector=["output", "result"],
error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
)
assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
def test_iteration_node_data_flatten_output_disabled(self):
"""Test IterationNodeData with flatten output disabled."""
data = IterationNodeData(
title="No Flatten",
iterator_selector=["a"],
output_selector=["b"],
flatten_output=False,
)
assert data.flatten_output is False
def test_iteration_node_data_with_parent_loop_id(self):
"""Test IterationNodeData with parent loop ID."""
data = IterationNodeData(
title="Nested Loop",
iterator_selector=["parent", "items"],
output_selector=["child", "output"],
parent_loop_id="parent_loop_123",
)
assert data.parent_loop_id == "parent_loop_123"
def test_iteration_node_data_complex_selectors(self):
"""Test IterationNodeData with complex selectors."""
data = IterationNodeData(
title="Complex Selectors",
iterator_selector=["node1", "output", "data", "items"],
output_selector=["iteration", "result", "value"],
)
assert len(data.iterator_selector) == 4
assert len(data.output_selector) == 3
def test_iteration_node_data_all_options(self):
"""Test IterationNodeData with all options configured."""
data = IterationNodeData(
title="Full Config",
iterator_selector=["start", "list"],
output_selector=["end", "result"],
parent_loop_id="outer_loop",
is_parallel=True,
parallel_nums=15,
error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR,
flatten_output=False,
)
assert data.title == "Full Config"
assert data.parent_loop_id == "outer_loop"
assert data.is_parallel is True
assert data.parallel_nums == 15
assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
assert data.flatten_output is False
class TestIterationStartNodeData:
"""Test suite for IterationStartNodeData model."""
def test_iteration_start_node_data_basic(self):
"""Test IterationStartNodeData basic creation."""
data = IterationStartNodeData(title="Iteration Start")
assert data.title == "Iteration Start"
def test_iteration_start_node_data_with_description(self):
"""Test IterationStartNodeData with description."""
data = IterationStartNodeData(
title="Start Node",
desc="This is the start of iteration",
)
assert data.title == "Start Node"
assert data.desc == "This is the start of iteration"
class TestIterationState:
"""Test suite for IterationState model."""
def test_iteration_state_default_values(self):
"""Test IterationState default values."""
state = IterationState()
assert state.outputs == []
assert state.current_output is None
def test_iteration_state_with_outputs(self):
"""Test IterationState with outputs."""
state = IterationState(outputs=["result1", "result2", "result3"])
assert len(state.outputs) == 3
assert state.outputs[0] == "result1"
assert state.outputs[2] == "result3"
def test_iteration_state_with_current_output(self):
"""Test IterationState with current output."""
state = IterationState(current_output="current_value")
assert state.current_output == "current_value"
def test_iteration_state_get_last_output_with_outputs(self):
"""Test get_last_output with outputs present."""
state = IterationState(outputs=["first", "second", "last"])
result = state.get_last_output()
assert result == "last"
def test_iteration_state_get_last_output_empty(self):
"""Test get_last_output with empty outputs."""
state = IterationState(outputs=[])
result = state.get_last_output()
assert result is None
def test_iteration_state_get_last_output_single(self):
"""Test get_last_output with single output."""
state = IterationState(outputs=["only_one"])
result = state.get_last_output()
assert result == "only_one"
def test_iteration_state_get_current_output(self):
"""Test get_current_output method."""
state = IterationState(current_output={"key": "value"})
result = state.get_current_output()
assert result == {"key": "value"}
def test_iteration_state_get_current_output_none(self):
"""Test get_current_output when None."""
state = IterationState()
result = state.get_current_output()
assert result is None
def test_iteration_state_with_complex_outputs(self):
"""Test IterationState with complex output types."""
state = IterationState(
outputs=[
{"id": 1, "name": "first"},
{"id": 2, "name": "second"},
[1, 2, 3],
"string_output",
]
)
assert len(state.outputs) == 4
assert state.outputs[0] == {"id": 1, "name": "first"}
assert state.outputs[2] == [1, 2, 3]
def test_iteration_state_with_none_outputs(self):
"""Test IterationState with None values in outputs."""
state = IterationState(outputs=["value1", None, "value3"])
assert len(state.outputs) == 3
assert state.outputs[1] is None
def test_iteration_state_get_last_output_with_none(self):
"""Test get_last_output when last output is None."""
state = IterationState(outputs=["first", None])
result = state.get_last_output()
assert result is None
def test_iteration_state_metadata_class(self):
"""Test IterationState.MetaData class."""
metadata = IterationState.MetaData(iterator_length=10)
assert metadata.iterator_length == 10
def test_iteration_state_metadata_different_lengths(self):
"""Test IterationState.MetaData with different lengths."""
metadata1 = IterationState.MetaData(iterator_length=0)
metadata2 = IterationState.MetaData(iterator_length=100)
metadata3 = IterationState.MetaData(iterator_length=1000000)
assert metadata1.iterator_length == 0
assert metadata2.iterator_length == 100
assert metadata3.iterator_length == 1000000
def test_iteration_state_outputs_modification(self):
"""Test modifying IterationState outputs."""
state = IterationState(outputs=[])
state.outputs.append("new_output")
state.outputs.append("another_output")
assert len(state.outputs) == 2
assert state.get_last_output() == "another_output"
def test_iteration_state_current_output_update(self):
"""Test updating current_output."""
state = IterationState()
state.current_output = "first_value"
assert state.get_current_output() == "first_value"
state.current_output = "updated_value"
assert state.get_current_output() == "updated_value"
def test_iteration_state_with_numeric_outputs(self):
"""Test IterationState with numeric outputs."""
state = IterationState(outputs=[1, 2, 3, 4, 5])
assert state.get_last_output() == 5
assert len(state.outputs) == 5
def test_iteration_state_with_boolean_outputs(self):
"""Test IterationState with boolean outputs."""
state = IterationState(outputs=[True, False, True])
assert state.get_last_output() is True
assert state.outputs[1] is False

View File

@ -0,0 +1,390 @@
from core.workflow.enums import NodeType
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.nodes.iteration.exc import (
InvalidIteratorValueError,
IterationGraphNotFoundError,
IterationIndexNotFoundError,
IterationNodeError,
IteratorVariableNotFoundError,
StartNodeIdNotFoundError,
)
from core.workflow.nodes.iteration.iteration_node import IterationNode
class TestIterationNodeExceptions:
"""Test suite for iteration node exceptions."""
def test_iteration_node_error_is_value_error(self):
"""Test IterationNodeError inherits from ValueError."""
error = IterationNodeError("test error")
assert isinstance(error, ValueError)
assert str(error) == "test error"
def test_iterator_variable_not_found_error(self):
"""Test IteratorVariableNotFoundError."""
error = IteratorVariableNotFoundError("Iterator variable not found")
assert isinstance(error, IterationNodeError)
assert isinstance(error, ValueError)
assert "Iterator variable not found" in str(error)
def test_invalid_iterator_value_error(self):
"""Test InvalidIteratorValueError."""
error = InvalidIteratorValueError("Invalid iterator value")
assert isinstance(error, IterationNodeError)
assert "Invalid iterator value" in str(error)
def test_start_node_id_not_found_error(self):
"""Test StartNodeIdNotFoundError."""
error = StartNodeIdNotFoundError("Start node ID not found")
assert isinstance(error, IterationNodeError)
assert "Start node ID not found" in str(error)
def test_iteration_graph_not_found_error(self):
"""Test IterationGraphNotFoundError."""
error = IterationGraphNotFoundError("Iteration graph not found")
assert isinstance(error, IterationNodeError)
assert "Iteration graph not found" in str(error)
def test_iteration_index_not_found_error(self):
"""Test IterationIndexNotFoundError."""
error = IterationIndexNotFoundError("Iteration index not found")
assert isinstance(error, IterationNodeError)
assert "Iteration index not found" in str(error)
def test_exception_with_empty_message(self):
"""Test exception with empty message."""
error = IterationNodeError("")
assert str(error) == ""
def test_exception_with_detailed_message(self):
"""Test exception with detailed message."""
error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'")
assert "items" in str(error)
assert "start_node" in str(error)
def test_all_exceptions_inherit_from_base(self):
"""Test all exceptions inherit from IterationNodeError."""
exceptions = [
IteratorVariableNotFoundError("test"),
InvalidIteratorValueError("test"),
StartNodeIdNotFoundError("test"),
IterationGraphNotFoundError("test"),
IterationIndexNotFoundError("test"),
]
for exc in exceptions:
assert isinstance(exc, IterationNodeError)
assert isinstance(exc, ValueError)
class TestIterationNodeClassAttributes:
"""Test suite for IterationNode class attributes."""
def test_node_type(self):
"""Test IterationNode node_type attribute."""
assert IterationNode.node_type == NodeType.ITERATION
def test_version(self):
"""Test IterationNode version method."""
version = IterationNode.version()
assert version == "1"
class TestIterationNodeDefaultConfig:
"""Test suite for IterationNode get_default_config."""
def test_get_default_config_returns_dict(self):
"""Test get_default_config returns a dictionary."""
config = IterationNode.get_default_config()
assert isinstance(config, dict)
def test_get_default_config_type(self):
"""Test get_default_config includes type."""
config = IterationNode.get_default_config()
assert config.get("type") == "iteration"
def test_get_default_config_has_config_section(self):
"""Test get_default_config has config section."""
config = IterationNode.get_default_config()
assert "config" in config
assert isinstance(config["config"], dict)
def test_get_default_config_is_parallel_default(self):
"""Test get_default_config is_parallel default value."""
config = IterationNode.get_default_config()
assert config["config"]["is_parallel"] is False
def test_get_default_config_parallel_nums_default(self):
"""Test get_default_config parallel_nums default value."""
config = IterationNode.get_default_config()
assert config["config"]["parallel_nums"] == 10
def test_get_default_config_error_handle_mode_default(self):
"""Test get_default_config error_handle_mode default value."""
config = IterationNode.get_default_config()
assert config["config"]["error_handle_mode"] == ErrorHandleMode.TERMINATED
def test_get_default_config_flatten_output_default(self):
"""Test get_default_config flatten_output default value."""
config = IterationNode.get_default_config()
assert config["config"]["flatten_output"] is True
def test_get_default_config_with_none_filters(self):
"""Test get_default_config with None filters."""
config = IterationNode.get_default_config(filters=None)
assert config is not None
assert "type" in config
def test_get_default_config_with_empty_filters(self):
"""Test get_default_config with empty filters."""
config = IterationNode.get_default_config(filters={})
assert config is not None
class TestIterationNodeInitialization:
"""Test suite for IterationNode initialization."""
def test_init_node_data_basic(self):
"""Test init_node_data with basic configuration."""
node = IterationNode.__new__(IterationNode)
data = {
"title": "Test Iteration",
"iterator_selector": ["start", "items"],
"output_selector": ["iteration", "result"],
}
node.init_node_data(data)
assert node._node_data.title == "Test Iteration"
assert node._node_data.iterator_selector == ["start", "items"]
def test_init_node_data_with_parallel(self):
"""Test init_node_data with parallel configuration."""
node = IterationNode.__new__(IterationNode)
data = {
"title": "Parallel Iteration",
"iterator_selector": ["node", "list"],
"output_selector": ["out", "result"],
"is_parallel": True,
"parallel_nums": 5,
}
node.init_node_data(data)
assert node._node_data.is_parallel is True
assert node._node_data.parallel_nums == 5
def test_init_node_data_with_error_handle_mode(self):
"""Test init_node_data with error handle mode."""
node = IterationNode.__new__(IterationNode)
data = {
"title": "Error Handle Test",
"iterator_selector": ["a", "b"],
"output_selector": ["c", "d"],
"error_handle_mode": "continue-on-error",
}
node.init_node_data(data)
assert node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR
def test_get_title(self):
"""Test _get_title method."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="My Iteration",
iterator_selector=["x"],
output_selector=["y"],
)
assert node._get_title() == "My Iteration"
def test_get_description_none(self):
"""Test _get_description returns None when not set."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
iterator_selector=["a"],
output_selector=["b"],
)
assert node._get_description() is None
def test_get_description_with_value(self):
"""Test _get_description with value."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
desc="This is a description",
iterator_selector=["a"],
output_selector=["b"],
)
assert node._get_description() == "This is a description"
def test_get_base_node_data(self):
"""Test get_base_node_data returns node data."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Base Test",
iterator_selector=["x"],
output_selector=["y"],
)
result = node.get_base_node_data()
assert result == node._node_data
class TestIterationNodeDataValidation:
"""Test suite for IterationNodeData validation scenarios."""
def test_valid_iteration_node_data(self):
"""Test valid IterationNodeData creation."""
data = IterationNodeData(
title="Valid Iteration",
iterator_selector=["start", "items"],
output_selector=["end", "result"],
)
assert data.title == "Valid Iteration"
def test_iteration_node_data_with_all_error_modes(self):
"""Test IterationNodeData with all error handle modes."""
modes = [
ErrorHandleMode.TERMINATED,
ErrorHandleMode.CONTINUE_ON_ERROR,
ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT,
]
for mode in modes:
data = IterationNodeData(
title=f"Test {mode}",
iterator_selector=["a"],
output_selector=["b"],
error_handle_mode=mode,
)
assert data.error_handle_mode == mode
def test_iteration_node_data_parallel_configuration(self):
"""Test IterationNodeData parallel configuration combinations."""
configs = [
(False, 10),
(True, 1),
(True, 5),
(True, 20),
(True, 100),
]
for is_parallel, parallel_nums in configs:
data = IterationNodeData(
title="Parallel Test",
iterator_selector=["x"],
output_selector=["y"],
is_parallel=is_parallel,
parallel_nums=parallel_nums,
)
assert data.is_parallel == is_parallel
assert data.parallel_nums == parallel_nums
def test_iteration_node_data_flatten_output_options(self):
"""Test IterationNodeData flatten_output options."""
data_flatten = IterationNodeData(
title="Flatten True",
iterator_selector=["a"],
output_selector=["b"],
flatten_output=True,
)
data_no_flatten = IterationNodeData(
title="Flatten False",
iterator_selector=["a"],
output_selector=["b"],
flatten_output=False,
)
assert data_flatten.flatten_output is True
assert data_no_flatten.flatten_output is False
def test_iteration_node_data_complex_selectors(self):
"""Test IterationNodeData with complex selectors."""
data = IterationNodeData(
title="Complex",
iterator_selector=["node1", "output", "data", "items", "list"],
output_selector=["iteration", "result", "value", "final"],
)
assert len(data.iterator_selector) == 5
assert len(data.output_selector) == 4
def test_iteration_node_data_single_element_selectors(self):
"""Test IterationNodeData with single element selectors."""
data = IterationNodeData(
title="Single",
iterator_selector=["items"],
output_selector=["result"],
)
assert len(data.iterator_selector) == 1
assert len(data.output_selector) == 1
class TestIterationNodeErrorStrategies:
"""Test suite for IterationNode error strategies."""
def test_get_error_strategy_default(self):
"""Test _get_error_strategy with default value."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
iterator_selector=["a"],
output_selector=["b"],
)
result = node._get_error_strategy()
assert result is None or result == node._node_data.error_strategy
def test_get_retry_config(self):
"""Test _get_retry_config method."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
iterator_selector=["a"],
output_selector=["b"],
)
result = node._get_retry_config()
assert result is not None
def test_get_default_value_dict(self):
"""Test _get_default_value_dict method."""
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Test",
iterator_selector=["a"],
output_selector=["b"],
)
result = node._get_default_value_dict()
assert isinstance(result, dict)