From 168d82f9b0ab71fa8c0dcc09f5f506cd919f249f Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Fri, 1 Aug 2025 13:16:04 +0800 Subject: [PATCH] feat(api): support boolean types in parameter extractor node --- .../nodes/parameter_extractor/entities.py | 93 +++++++--- .../workflow/nodes/parameter_extractor/exc.py | 25 +++ .../parameter_extractor_node.py | 159 +++++++++--------- .../workflow/utils/condition/processor.py | 6 +- 4 files changed, 183 insertions(+), 100 deletions(-) diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 916778d167..0bc92fb600 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,10 +1,46 @@ -from typing import Any, Literal, Optional +from typing import Annotated, Any, Literal, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import ( + BaseModel, + BeforeValidator, + Field, + field_validator, +) from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.variables.types import SegmentType from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm import ModelConfig, VisionConfig +from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig + +_OLD_BOOL_TYPE_NAME = "bool" +_OLD_SELECT_TYPE_NAME = "select" + +_VALID_PARAMETER_TYPES = frozenset( + [ + SegmentType.STRING, # "string", + SegmentType.NUMBER, # "number", + SegmentType.BOOLEAN, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_BOOLEAN, + _OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node + _OLD_SELECT_TYPE_NAME, # string type with enumeration choices. 
+ ] +) + + +def _validate_type(parameter_type: str) -> SegmentType: + if not isinstance(parameter_type, str): + raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}") + if parameter_type not in _VALID_PARAMETER_TYPES: + raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.") + + if parameter_type == _OLD_BOOL_TYPE_NAME: + return SegmentType.BOOLEAN + elif parameter_type == _OLD_SELECT_TYPE_NAME: + return SegmentType.STRING + return SegmentType(parameter_type) class _ParameterConfigError(Exception): @@ -17,11 +53,25 @@ class ParameterConfig(BaseModel): """ name: str - type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"] + type: Annotated[SegmentType, BeforeValidator(_validate_type)] options: Optional[list[str]] = None description: str required: bool + _is_old_select_type: bool = PrivateAttr(default=False) + + @model_validator(mode="wrap") + @classmethod + def log_failed_validation(cls, data: Any, handler: ModelWrapValidatorHandler[Self]) -> Self: + if not isinstance(data, dict): + return handler(data) + + original_type = data.get("type") + instance = handler(data) + if original_type == _OLD_SELECT_TYPE_NAME: + instance._is_old_select_type = True + return instance + @field_validator("name", mode="before") @classmethod def validate_name(cls, value) -> str: @@ -32,17 +82,20 @@ class ParameterConfig(BaseModel): return str(value) def is_array_type(self) -> bool: - return self.type in ("array[string]", "array[number]", "array[object]") + return self.type.is_array_type() - def element_type(self) -> Literal["string", "number", "object"]: - if self.type == "array[number]": - return "number" - elif self.type == "array[string]": - return "string" - elif self.type == "array[object]": - return "object" - else: - raise _ParameterConfigError(f"{self.type} is not array type.") + def element_type(self) -> SegmentType: + """Return the element type of the 
parameter. + + Raises a ValueError if the parameter's type is not an array type. + """ + element_type = self.type.element_type() + # At this point, self.type is guaranteed to be one of `ARRAY_STRING`, + # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`. + # + # See: _VALID_PARAMETER_TYPES for reference. + assert element_type is not None, f"the element type should not be None, {self.type=}" + return element_type class ParameterExtractorNodeData(BaseNodeData): @@ -74,16 +127,18 @@ class ParameterExtractorNodeData(BaseNodeData): for parameter in self.parameters: parameter_schema: dict[str, Any] = {"description": parameter.description} - if parameter.type in {"string", "select"}: + if parameter.type == SegmentType.STRING: parameter_schema["type"] = "string" - elif parameter.type.startswith("array"): + elif parameter.type.is_array_type(): parameter_schema["type"] = "array" - nested_type = parameter.type[6:-1] - parameter_schema["items"] = {"type": nested_type} + element_type = parameter.type.element_type() + if element_type is None: + raise AssertionError("element type should not be None.") + parameter_schema["items"] = {"type": element_type.value} else: parameter_schema["type"] = parameter.type - if parameter.type == "select": + if parameter._is_old_select_type: parameter_schema["enum"] = parameter.options parameters["properties"][parameter.name] = parameter_schema diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py index 6511aba185..247518cf20 100644 --- a/api/core/workflow/nodes/parameter_extractor/exc.py +++ b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -1,3 +1,8 @@ +from typing import Any + +from core.variables.types import SegmentType + + class ParameterExtractorNodeError(ValueError): """Base error for ParameterExtractorNode.""" @@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError): class InvalidModelModeError(ParameterExtractorNodeError): """Raised when the model 
mode is invalid.""" + + +class InvalidValueTypeError(ParameterExtractorNodeError): + def __init__( + self, + /, + parameter_name: str, + expected_type: SegmentType, + actual_type: SegmentType | None, + value: Any, + ) -> None: + message = ( + f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, " + f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}" + ) + super().__init__(message) + self.parameter_name = parameter_name + self.expected_type = expected_type + self.actual_type = actual_type + self.value = value diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 45c5e0a62c..7ce0152cdd 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -25,7 +25,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.variables.types import SegmentType +from core.variables.types import ArrayValidation, SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -38,16 +38,13 @@ from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData from .exc import ( - InvalidArrayValueError, - InvalidBoolValueError, InvalidInvokeResultError, InvalidModelModeError, InvalidModelTypeError, InvalidNumberOfParametersError, - InvalidNumberValueError, InvalidSelectValueError, - InvalidStringValueError, 
InvalidTextContentTypeError, + InvalidValueTypeError, ModelSchemaNotFoundError, ParameterExtractorNodeError, RequiredParameterMissingError, @@ -548,9 +545,6 @@ class ParameterExtractorNode(BaseNode): return prompt_messages def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: - """ - Validate result. - """ if len(data.parameters) != len(result): raise InvalidNumberOfParametersError("Invalid number of parameters") @@ -558,101 +552,106 @@ class ParameterExtractorNode(BaseNode): if parameter.required and parameter.name not in result: raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") - if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: - raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") - - if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): - raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}") - - if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): - raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}") - - if parameter.type == "string" and not isinstance(result.get(parameter.name), str): - raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}") - - if parameter.type.startswith("array"): - parameters = result.get(parameter.name) - if not isinstance(parameters, list): - raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}") - nested_type = parameter.type[6:-1] - for item in parameters: - if nested_type == "number" and not isinstance(item, int | float): - raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}") - if nested_type == "string" and not isinstance(item, str): - raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}") - if nested_type == 
"object" and not isinstance(item, dict): - raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}") + param_value = result.get(parameter.name) + if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL): + inferred_type = SegmentType.infer_segment_type(param_value) + raise InvalidValueTypeError( + parameter_name=parameter.name, + expected_type=parameter.type, + actual_type=inferred_type, + value=param_value, + ) + if parameter.type == SegmentType.STRING and parameter.options: + if param_value not in parameter.options: + raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") return result + @staticmethod + def _transform_number(value: int | float | str | bool) -> int | float | None: + """ + Attempts to transform the input into an integer or float. + + Returns: + int or float: The transformed number if the conversion is successful. + None: If the transformation fails. + + Note: + Boolean values `True` and `False` are converted to integers `1` and `0`, respectively. + This behavior ensures compatibility with existing workflows that may use boolean types as integers. + """ + if isinstance(value, bool): + return int(value) + elif isinstance(value, (int, float)): + return value + elif not isinstance(value, str): + return None + if "." in value: + try: + return float(value) + except ValueError: + return None + else: + try: + return int(value) + except ValueError: + return None + def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: """ Transform result into standard format. 
""" - transformed_result = {} + transformed_result: dict[str, Any] = {} for parameter in data.parameters: if parameter.name in result: + param_value = result[parameter.name] # transform value - if parameter.type == "number": - if isinstance(result[parameter.name], int | float): - transformed_result[parameter.name] = result[parameter.name] - elif isinstance(result[parameter.name], str): - try: - if "." in result[parameter.name]: - result[parameter.name] = float(result[parameter.name]) - else: - result[parameter.name] = int(result[parameter.name]) - except ValueError: - pass - else: - pass - # TODO: bool is not supported in the current version - # elif parameter.type == 'bool': - # if isinstance(result[parameter.name], bool): - # transformed_result[parameter.name] = bool(result[parameter.name]) - # elif isinstance(result[parameter.name], str): - # if result[parameter.name].lower() in ['true', 'false']: - # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') - # elif isinstance(result[parameter.name], int): - # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in {"string", "select"}: - if isinstance(result[parameter.name], str): - transformed_result[parameter.name] = result[parameter.name] + if parameter.type == SegmentType.NUMBER: + transformed = self._transform_number(param_value) + if transformed is not None: + transformed_result[parameter.name] = transformed + elif parameter.type == SegmentType.BOOLEAN: + if isinstance(result[parameter.name], (bool, int)): + transformed_result[parameter.name] = bool(result[parameter.name]) + # elif isinstance(result[parameter.name], str): + # if result[parameter.name].lower() in ["true", "false"]: + # transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true") + elif parameter.type == SegmentType.STRING: + if isinstance(param_value, str): + transformed_result[parameter.name] = param_value elif parameter.is_array_type(): - if 
isinstance(result[parameter.name], list): + if isinstance(param_value, list): nested_type = parameter.element_type() assert nested_type is not None segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) transformed_result[parameter.name] = segment_value - for item in result[parameter.name]: - if nested_type == "number": - if isinstance(item, int | float): - segment_value.value.append(item) - elif isinstance(item, str): - try: - if "." in item: - segment_value.value.append(float(item)) - else: - segment_value.value.append(int(item)) - except ValueError: - pass - elif nested_type == "string": + for item in param_value: + if nested_type == SegmentType.NUMBER: + transformed = self._transform_number(item) + if transformed is not None: + segment_value.value.append(transformed) + elif nested_type == SegmentType.STRING: if isinstance(item, str): segment_value.value.append(item) - elif nested_type == "object": + elif nested_type == SegmentType.OBJECT: if isinstance(item, dict): segment_value.value.append(item) + elif nested_type == SegmentType.BOOLEAN: + if isinstance(item, bool): + segment_value.value.append(item) if parameter.name not in transformed_result: - if parameter.type == "number": - transformed_result[parameter.name] = 0 - elif parameter.type == "bool": - transformed_result[parameter.name] = False - elif parameter.type in {"string", "select"}: - transformed_result[parameter.name] = "" - elif parameter.type.startswith("array"): + if parameter.type.is_array_type(): transformed_result[parameter.name] = build_segment_with_type( segment_type=SegmentType(parameter.type), value=[] ) + elif parameter.type in (SegmentType.STRING, SegmentType.SECRET): + transformed_result[parameter.name] = "" + elif parameter.type == SegmentType.NUMBER: + transformed_result[parameter.name] = 0 + elif parameter.type == SegmentType.BOOLEAN: + transformed_result[parameter.name] = False + else: + raise AssertionError("this statement should be unreachable.") 
return transformed_result diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 6bc1577c91..d74dca25af 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -3,6 +3,7 @@ from typing import Any, Literal, Union from core.file import FileAttribute, file_manager from core.variables import ArrayFileSegment +from core.variables.segments import BooleanSegment from core.workflow.entities.variable_pool import VariablePool from .entities import Condition, SubCondition, SupportedComparisonOperator @@ -51,6 +52,9 @@ class ConditionProcessor: expected_value = condition.value if isinstance(expected_value, str): expected_value = variable_pool.convert_template(expected_value).text + # Here we need to explicitly convert the input string to boolean. + if isinstance(variable, BooleanSegment) and not variable.value_type.is_valid(expected_value): + raise TypeError(f"unexpected value: type={type(expected_value)}, value={expected_value}") input_conditions.append( { "actual_value": actual_value, @@ -77,7 +81,7 @@ def _evaluate_condition( *, operator: SupportedComparisonOperator, value: Any, - expected: Union[str, Sequence[str], None], + expected: Union[str, Sequence[str], bool, None], ) -> bool: match operator: case "contains":