diff --git a/api/core/workflow/nodes/list_operator/entities.py b/api/core/workflow/nodes/list_operator/entities.py index 75df784a92..550ed0b94c 100644 --- a/api/core/workflow/nodes/list_operator/entities.py +++ b/api/core/workflow/nodes/list_operator/entities.py @@ -1,35 +1,40 @@ from collections.abc import Sequence -from typing import Literal +from enum import StrEnum from pydantic import BaseModel, Field from core.workflow.nodes.base import BaseNodeData -_Condition = Literal[ + +class FilterOperator(StrEnum): # string conditions - "contains", - "start with", - "end with", - "is", - "in", - "empty", - "not contains", - "is not", - "not in", - "not empty", + CONTAINS = "contains" + START_WITH = "start with" + END_WITH = "end with" + IS = "is" + IN = "in" + EMPTY = "empty" + NOT_CONTAINS = "not contains" + IS_NOT = "is not" + NOT_IN = "not in" + NOT_EMPTY = "not empty" # number conditions - "=", - "≠", - "<", - ">", - "≥", - "≤", -] + EQUAL = "=" + NOT_EQUAL = "≠" + LESS_THAN = "<" + GREATER_THAN = ">" + GREATER_THAN_OR_EQUAL = "≥" + LESS_THAN_OR_EQUAL = "≤" + + +class Order(StrEnum): + ASC = "asc" + DESC = "desc" class FilterCondition(BaseModel): key: str = "" - comparison_operator: _Condition = "contains" + comparison_operator: FilterOperator = FilterOperator.CONTAINS value: str | Sequence[str] = "" @@ -38,10 +43,10 @@ class FilterBy(BaseModel): conditions: Sequence[FilterCondition] = Field(default_factory=list) -class OrderBy(BaseModel): +class OrderByConfig(BaseModel): enabled: bool = False key: str = "" - value: Literal["asc", "desc"] = "asc" + value: Order = Order.ASC class Limit(BaseModel): @@ -57,6 +62,6 @@ class ExtractConfig(BaseModel): class ListOperatorNodeData(BaseNodeData): variable: Sequence[str] = Field(default_factory=list) filter_by: FilterBy - order_by: OrderBy + order_by: OrderByConfig limit: Limit extract_by: ExtractConfig = Field(default_factory=ExtractConfig) diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index d2e022dc9d..6f1f05164f 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,18 +1,41 @@ +import json from collections.abc import Callable, Mapping, Sequence -from typing import Any, Literal, Optional, Union +from typing import Any, Optional, TypeAlias, TypeVar from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from core.variables.segments import ArrayAnySegment, ArraySegment +from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.enums import ErrorStrategy, NodeType -from .entities import ListOperatorNodeData +from .entities import FilterOperator, ListOperatorNodeData, Order from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError +_SUPPORTED_TYPES_TUPLE = ( + ArrayFileSegment, + ArrayNumberSegment, + ArrayStringSegment, + ArrayBooleanSegment, +) +_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment + + +_T = TypeVar("_T") + + +def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: + """Returns the negation of a given filter function. If the original filter + returns `True` for a value, the negated filter will return `False`, and vice versa. + """ + + def wrapper(value: _T) -> bool: + return not filter_(value) + + return wrapper + class ListOperatorNode(BaseNode): _node_type = NodeType.LIST_OPERATOR @@ -69,11 +92,8 @@ class ListOperatorNode(BaseNode): process_data=process_data, outputs=outputs, ) - if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): - error_message = ( - f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " - "or ArrayStringSegment" - ) + if not isinstance(variable, _SUPPORTED_TYPES_TUPLE): + error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}" return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) @@ -122,9 +142,7 @@ class ListOperatorNode(BaseNode): outputs=outputs, ) - def _apply_filter( - self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] - ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: filter_func: Callable[[Any], bool] result: list[Any] = [] for condition in self._node_data.filter_by.conditions: @@ -154,33 +172,39 @@ class ListOperatorNode(BaseNode): ) result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayBooleanSegment): + if not isinstance(condition.value, str): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + value = json.loads(value) + if not isinstance(value, bool): + raise ValueError(f"value for boolean filter should be boolean values, got {value}") + filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=value) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + else: + raise AssertionError("this statment should be unreachable.") return variable - def _apply_order( - self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] - ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self._node_data.order_by.value, array=variable.value) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self._node_data.order_by.value, array=variable.value) + def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: + if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)): + result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayFileSegment): result = _order_file( order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value ) variable = variable.model_copy(update={"value": result}) + else: + raise AssertionError("this statement should be unreachable") + return variable - def _apply_slice( - self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] - ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: result = variable.value[: self._node_data.limit.size] return variable.model_copy(update={"value": result}) - def _extract_slice( - self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] - ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text) if value < 1: raise ValueError(f"Invalid serial index: must be >= 1, got {value}") @@ -232,11 +256,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo case "empty": return lambda x: x == "" case "not contains": - return lambda x: not _contains(value)(x) + return _negation(_contains(value)) case "is not": - return lambda x: not _is(value)(x) + return _negation(_is(value)) case "not in": - return lambda x: not _in(value)(x) + return _negation(_in(value)) case "not empty": return lambda x: x != "" case _: @@ -248,7 +272,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab case "in": return _in(value) case "not in": - return lambda x: not _in(value)(x) + return _negation(_in(value)) case _: raise InvalidConditionError(f"Invalid condition: {condition}") @@ -271,6 +295,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ raise InvalidConditionError(f"Invalid condition: {condition}") +def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]: + match condition: + case FilterOperator.IS: + return _is(value) + case FilterOperator.IS_NOT: + return _negation(_is(value)) + case _: + raise InvalidConditionError(f"Invalid condition: {condition}") + + def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): @@ -298,7 +332,7 @@ def _endswith(value: str) -> Callable[[str], bool]: return lambda x: x.endswith(value) -def _is(value: str) -> Callable[[str], bool]: +def _is(value: _T) -> Callable[[_T], bool]: return lambda x: x == value @@ -330,21 +364,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]: return lambda x: x >= value -def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]): - return sorted(array, key=lambda x: x, reverse=order == "desc") - - -def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): - return sorted(array, key=lambda x: x, reverse=order == "desc") - - -def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): +def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]): extract_func: Callable[[File], Any] if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}: extract_func = _get_file_extract_string_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) elif order_by == "size": extract_func = _get_file_extract_number_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) else: raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 5fc9eab2df..b03ce3e1c6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -11,7 +11,7 @@ from core.workflow.nodes.list_operator.entities import ( FilterCondition, Limit, ListOperatorNodeData, - OrderBy, + Order, ) from core.workflow.nodes.list_operator.exc import InvalidKeyError from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func @@ -27,7 +27,7 @@ def list_operator_node(): FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT]) ], ), - "order_by": OrderBy(enabled=False, value="asc"), + "order_by": Order(enabled=False, value="asc"), "limit": Limit(enabled=False, size=0), "extract_by": ExtractConfig(enabled=False, serial="1"), "title": "Test Title",