feat(api): support boolean types in parameter extractor node

This commit is contained in:
QuantumGhost 2025-08-01 13:16:04 +08:00
parent 0e3ccb4dcc
commit 168d82f9b0
4 changed files with 183 additions and 100 deletions

View File

@ -1,10 +1,46 @@
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import (
BaseModel,
BeforeValidator,
Field,
field_validator,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
_OLD_BOOL_TYPE_NAME = "bool"
_OLD_SELECT_TYPE_NAME = "select"
_VALID_PARAMETER_TYPES = frozenset(
[
SegmentType.STRING, # "string",
SegmentType.NUMBER, # "number",
SegmentType.BOOLEAN,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_BOOLEAN,
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
]
)
def _validate_type(parameter_type: str) -> SegmentType:
if not isinstance(parameter_type, str):
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
if parameter_type not in _VALID_PARAMETER_TYPES:
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
if parameter_type == _OLD_BOOL_TYPE_NAME:
return SegmentType.BOOLEAN
elif parameter_type == _OLD_SELECT_TYPE_NAME:
return SegmentType.STRING
return SegmentType(parameter_type)
class _ParameterConfigError(Exception):
@ -17,11 +53,25 @@ class ParameterConfig(BaseModel):
"""
name: str
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
options: Optional[list[str]] = None
description: str
required: bool
_is_old_select_type: bool = PrivateAttr(default=False)
@model_validator(mode="wrap")
@classmethod
def log_failed_validation(cls, data: Any, handler: ModelWrapValidatorHandler[Self]) -> Self:
if not isinstance(data, dict):
return handler(data)
original_type = data.get("type")
instance = handler(data)
if original_type == _OLD_SELECT_TYPE_NAME:
instance._is_old_select_type = True
return instance
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value) -> str:
@ -32,17 +82,20 @@ class ParameterConfig(BaseModel):
return str(value)
def is_array_type(self) -> bool:
return self.type in ("array[string]", "array[number]", "array[object]")
return self.type.is_array_type()
def element_type(self) -> Literal["string", "number", "object"]:
if self.type == "array[number]":
return "number"
elif self.type == "array[string]":
return "string"
elif self.type == "array[object]":
return "object"
else:
raise _ParameterConfigError(f"{self.type} is not array type.")
def element_type(self) -> SegmentType:
"""Return the element type of the parameter.
Raises a ValueError if the parameter's type is not an array type.
"""
element_type = self.type.element_type()
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
#
# See: _VALID_PARAMETER_TYPES for reference.
assert element_type is not None, f"the element type should not be None, {self.type=}"
return element_type
class ParameterExtractorNodeData(BaseNodeData):
@ -74,16 +127,18 @@ class ParameterExtractorNodeData(BaseNodeData):
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type in {"string", "select"}:
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
elif parameter.type.startswith("array"):
elif parameter.type.is_array_type():
parameter_schema["type"] = "array"
nested_type = parameter.type[6:-1]
parameter_schema["items"] = {"type": nested_type}
element_type = parameter.type.element_type()
if element_type is None:
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
if parameter.type == "select":
if parameter._is_old_select_type:
parameter_schema["enum"] = parameter.options
parameters["properties"][parameter.name] = parameter_schema

View File

@ -1,3 +1,8 @@
from typing import Any
from core.variables.types import SegmentType
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""
class InvalidValueTypeError(ParameterExtractorNodeError):
def __init__(
self,
/,
parameter_name: str,
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
) -> None:
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
)
super().__init__(message)
self.parameter_name = parameter_name
self.expected_type = expected_type
self.actual_type = actual_type
self.value = value

View File

@ -25,7 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -38,16 +38,13 @@ from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidArrayValueError,
InvalidBoolValueError,
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidNumberValueError,
InvalidSelectValueError,
InvalidStringValueError,
InvalidTextContentTypeError,
InvalidValueTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
@ -548,9 +545,6 @@ class ParameterExtractorNode(BaseNode):
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Validate result.
"""
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
@ -558,101 +552,106 @@ class ParameterExtractorNode(BaseNode):
if parameter.required and parameter.name not in result:
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == "object" and not isinstance(item, dict):
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
param_value = result.get(parameter.name)
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
inferred_type = SegmentType.infer_segment_type(param_value)
raise InvalidValueTypeError(
parameter_name=parameter.name,
expected_type=parameter.type,
actual_type=inferred_type,
value=param_value,
)
if parameter.type == SegmentType.STRING and parameter.options:
if param_value not in parameter.options:
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
return result
@staticmethod
def _transform_number(value: int | float | str | bool) -> int | float | None:
"""
Attempts to transform the input into an integer or float.
Returns:
int or float: The transformed number if the conversion is successful.
None: If the transformation fails.
Note:
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
"""
if isinstance(value, bool):
return int(value)
elif isinstance(value, (int, float)):
return value
elif not isinstance(value, str):
return None
if "." in value:
try:
return float(value)
except ValueError:
return None
else:
try:
return int(value)
except ValueError:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Transform result into standard format.
"""
transformed_result = {}
transformed_result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == "number":
if isinstance(result[parameter.name], int | float):
transformed_result[parameter.name] = result[parameter.name]
elif isinstance(result[parameter.name], str):
try:
if "." in result[parameter.name]:
result[parameter.name] = float(result[parameter.name])
else:
result[parameter.name] = int(result[parameter.name])
except ValueError:
pass
else:
pass
# TODO: bool is not supported in the current version
# elif parameter.type == 'bool':
# if isinstance(result[parameter.name], bool):
# transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ['true', 'false']:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
# elif isinstance(result[parameter.name], int):
# transformed_result[parameter.name] = bool(result[parameter.name])
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ["true", "false"]:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
elif parameter.type == SegmentType.STRING:
if isinstance(param_value, str):
transformed_result[parameter.name] = param_value
elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
if isinstance(param_value, list):
nested_type = parameter.element_type()
assert nested_type is not None
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
segment_value.value.append(float(item))
else:
segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
for item in param_value:
if nested_type == SegmentType.NUMBER:
transformed = self._transform_number(item)
if transformed is not None:
segment_value.value.append(transformed)
elif nested_type == SegmentType.STRING:
if isinstance(item, str):
segment_value.value.append(item)
elif nested_type == "object":
elif nested_type == SegmentType.OBJECT:
if isinstance(item, dict):
segment_value.value.append(item)
elif nested_type == SegmentType.BOOLEAN:
if isinstance(item, bool):
segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type == "number":
transformed_result[parameter.name] = 0
elif parameter.type == "bool":
transformed_result[parameter.name] = False
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
if parameter.type.is_array_type():
transformed_result[parameter.name] = build_segment_with_type(
segment_type=SegmentType(parameter.type), value=[]
)
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
transformed_result[parameter.name] = ""
elif parameter.type == SegmentType.NUMBER:
transformed_result[parameter.name] = 0
elif parameter.type == SegmentType.BOOLEAN:
transformed_result[parameter.name] = False
else:
raise AssertionError("this statement should be unreachable.")
return transformed_result

View File

@ -3,6 +3,7 @@ from typing import Any, Literal, Union
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import BooleanSegment
from core.workflow.entities.variable_pool import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator
@ -51,6 +52,9 @@ class ConditionProcessor:
expected_value = condition.value
if isinstance(expected_value, str):
expected_value = variable_pool.convert_template(expected_value).text
# Here we need to explicit convet the input string to boolean.
if isinstance(variable, BooleanSegment) and not variable.value_type.is_valid(expected_value):
raise TypeError(f"unexpected value: type={type(expected_value)}, value={expected_value}")
input_conditions.append(
{
"actual_value": actual_value,
@ -77,7 +81,7 @@ def _evaluate_condition(
*,
operator: SupportedComparisonOperator,
value: Any,
expected: Union[str, Sequence[str], None],
expected: Union[str, Sequence[str], bool, None],
) -> bool:
match operator:
case "contains":