feat(api): Implement boolean types support for list operator node

This commit is contained in:
QuantumGhost 2025-07-30 10:54:07 +08:00
parent 111c34e50f
commit 3751702efc
3 changed files with 97 additions and 66 deletions

View File

@ -1,35 +1,40 @@
from collections.abc import Sequence
from typing import Literal
from enum import StrEnum
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
_Condition = Literal[
class FilterOperator(StrEnum):
# string conditions
"contains",
"start with",
"end with",
"is",
"in",
"empty",
"not contains",
"is not",
"not in",
"not empty",
CONTAINS = "contains"
START_WITH = "start with"
END_WITH = "end with"
IS = "is"
IN = "in"
EMPTY = "empty"
NOT_CONTAINS = "not contains"
IS_NOT = "is not"
NOT_IN = "not in"
NOT_EMPTY = "not empty"
# number conditions
"=",
"",
"<",
">",
"",
"",
]
EQUAL = "="
NOT_EQUAL = ""
LESS_THAN = "<"
GREATER_THAN = ">"
GREATER_THAN_OR_EQUAL = ""
LESS_THAN_OR_EQUAL = ""
class Order(StrEnum):
ASC = "asc"
DESC = "desc"
class FilterCondition(BaseModel):
key: str = ""
comparison_operator: _Condition = "contains"
comparison_operator: FilterOperator = FilterOperator.CONTAINS
value: str | Sequence[str] = ""
@ -38,10 +43,10 @@ class FilterBy(BaseModel):
conditions: Sequence[FilterCondition] = Field(default_factory=list)
class OrderBy(BaseModel):
class OrderByConfig(BaseModel):
enabled: bool = False
key: str = ""
value: Literal["asc", "desc"] = "asc"
value: Order = Order.ASC
class Limit(BaseModel):
@ -57,6 +62,6 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderBy
order_by: OrderByConfig
limit: Limit
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)

View File

@ -1,18 +1,41 @@
import json
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Optional, Union
from typing import Any, Optional, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData
from .entities import FilterOperator, ListOperatorNodeData, Order
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
_SUPPORTED_TYPES_TUPLE = (
ArrayFileSegment,
ArrayNumberSegment,
ArrayStringSegment,
ArrayBooleanSegment,
)
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
_T = TypeVar("_T")
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
"""Returns the negation of a given filter function. If the original filter
returns `True` for a value, the negated filter will return `False`, and vice versa.
"""
def wrapper(value: _T) -> bool:
return not filter_(value)
return wrapper
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
@ -69,11 +92,8 @@ class ListOperatorNode(BaseNode):
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -122,9 +142,7 @@ class ListOperatorNode(BaseNode):
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self._node_data.filter_by.conditions:
@ -154,33 +172,39 @@ class ListOperatorNode(BaseNode):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayBooleanSegment):
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
value = json.loads(value)
if not isinstance(value, bool):
raise ValueError(f"value for boolean filter should be boolean values, got {value}")
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statment should be unreachable.")
return variable
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
else:
raise AssertionError("this statement should be unreachable")
return variable
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
@ -232,11 +256,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "empty":
return lambda x: x == ""
case "not contains":
return lambda x: not _contains(value)(x)
return _negation(_contains(value))
case "is not":
return lambda x: not _is(value)(x)
return _negation(_is(value))
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case "not empty":
return lambda x: x != ""
case _:
@ -248,7 +272,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "in":
return _in(value)
case "not in":
return lambda x: not _in(value)(x)
return _negation(_in(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
@ -271,6 +295,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
match condition:
case FilterOperator.IS:
return _is(value)
case FilterOperator.IS_NOT:
return _negation(_is(value))
case _:
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
@ -298,7 +332,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
return lambda x: x.endswith(value)
def _is(value: str) -> Callable[[str], bool]:
def _is(value: _T) -> Callable[[_T], bool]:
return lambda x: x == value
@ -330,21 +364,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
return lambda x: x >= value
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
return sorted(array, key=lambda x: x, reverse=order == "desc")
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
extract_func: Callable[[File], Any]
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
extract_func = _get_file_extract_string_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
elif order_by == "size":
extract_func = _get_file_extract_number_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
else:
raise InvalidKeyError(f"Invalid order key: {order_by}")

View File

@ -11,7 +11,7 @@ from core.workflow.nodes.list_operator.entities import (
FilterCondition,
Limit,
ListOperatorNodeData,
OrderBy,
Order,
)
from core.workflow.nodes.list_operator.exc import InvalidKeyError
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
@ -27,7 +27,7 @@ def list_operator_node():
FilterCondition(key="type", comparison_operator="in", value=[FileType.IMAGE, FileType.DOCUMENT])
],
),
"order_by": OrderBy(enabled=False, value="asc"),
"order_by": Order(enabled=False, value="asc"),
"limit": Limit(enabled=False, size=0),
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",