diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index 36662a065e..c06a62d1e7 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -25,7 +25,7 @@ class GraphRuntimeState(BaseModel): llm_usage: LLMUsage | None = None, outputs: dict[str, Any] | None = None, node_run_steps: int = 0, - **kwargs, + **kwargs: object, ): """Initialize the GraphRuntimeState with validation.""" super().__init__(**kwargs) diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index fb0794844e..db5cbeca03 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -14,7 +14,7 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_V from core.workflow.system_variable import SystemVariable from factories import variable_factory -VariableValue = Union[str, int, float, dict, list, File] +VariableValue = Union[str, int, float, dict[str, object], list[object], File] VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") @@ -40,11 +40,11 @@ class VariablePool(BaseModel): ) environment_variables: Sequence[VariableUnion] = Field( description="Environment variables.", - default_factory=list, + default_factory=list[VariableUnion], ) conversation_variables: Sequence[VariableUnion] = Field( description="Conversation variables.", - default_factory=list, + default_factory=list[VariableUnion], ) def model_post_init(self, context: Any, /) -> None: @@ -191,7 +191,7 @@ class VariablePool(BaseModel): def convert_template(self, template: str, /): parts = VARIABLE_PATTERN.split(template) - segments = [] + segments: list[Segment] = [] for part in filter(lambda x: x, parts): if "." in part and (variable := self.get(part.split("."))): segments.append(variable) diff --git a/api/core/workflow/graph/graph_runtime_state_protocol.py b/api/core/workflow/graph/graph_runtime_state_protocol.py index a5c8db333a..d7961405ca 100644 --- a/api/core/workflow/graph/graph_runtime_state_protocol.py +++ b/api/core/workflow/graph/graph_runtime_state_protocol.py @@ -1,16 +1,18 @@ +from collections.abc import Mapping from typing import Any, Protocol from core.model_runtime.entities.llm_entities import LLMUsage +from core.variables.segments import Segment class ReadOnlyVariablePool(Protocol): """Read-only interface for VariablePool.""" - def get(self, node_id: str, variable_key: str) -> Any: + def get(self, node_id: str, variable_key: str) -> Segment | None: """Get a variable value (read-only).""" ... - def get_all_by_node(self, node_id: str) -> dict[str, Any]: + def get_all_by_node(self, node_id: str) -> Mapping[str, object]: """Get all variables for a node (read-only).""" ... diff --git a/api/core/workflow/graph/read_only_state_wrapper.py b/api/core/workflow/graph/read_only_state_wrapper.py index 3562106a4c..255bb5adee 100644 --- a/api/core/workflow/graph/read_only_state_wrapper.py +++ b/api/core/workflow/graph/read_only_state_wrapper.py @@ -1,7 +1,9 @@ +from collections.abc import Mapping from copy import deepcopy from typing import Any from core.model_runtime.entities.llm_entities import LLMUsage +from core.variables.segments import Segment from core.workflow.entities.graph_runtime_state import GraphRuntimeState from core.workflow.entities.variable_pool import VariablePool @@ -12,19 +14,18 @@ class ReadOnlyVariablePoolWrapper: def __init__(self, variable_pool: VariablePool): self._variable_pool = variable_pool - def get(self, node_id: str, variable_key: str) -> Any: + def get(self, node_id: str, variable_key: str) -> Segment | None: """Get a variable value (returns a defensive copy).""" - value = self._variable_pool.get(node_id, variable_key) + value = self._variable_pool.get([node_id, variable_key]) return deepcopy(value) if value is not None else None - def get_all_by_node(self, node_id: str) -> dict[str, Any]: + def get_all_by_node(self, node_id: str) -> Mapping[str, object]: """Get all variables for a node (returns defensive copies).""" - variables = {} + variables: dict[str, object] = {} if node_id in self._variable_pool.variable_dictionary: for key, var in self._variable_pool.variable_dictionary[node_id].items(): - # FIXME(-LAN-): Handle the actual Variable object structure - value = var.value if hasattr(var, "value") else var - variables[key] = deepcopy(value) + # Variables have a value property that contains the actual data + variables[key] = deepcopy(var.value) return variables diff --git a/api/core/workflow/node_events/base.py b/api/core/workflow/node_events/base.py index 3e9e239d30..7fec47e21f 100644 --- a/api/core/workflow/node_events/base.py +++ b/api/core/workflow/node_events/base.py @@ -13,6 +13,11 @@ class NodeEventBase(BaseModel): pass +def _default_metadata(): + v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} + return v + + class NodeRunResult(BaseModel): """ Node Run Result. @@ -23,7 +28,7 @@ class NodeRunResult(BaseModel): inputs: Mapping[str, Any] = Field(default_factory=dict) process_data: Mapping[str, Any] = Field(default_factory=dict) outputs: Mapping[str, Any] = Field(default_factory=dict) - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=dict) + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata) llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) edge_source_handle: str = "source" # source handle id of node with multiple branches diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 8689aa987b..f4bbe9c3c3 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import Any, Literal, NamedTuple, Union +from typing import Literal, NamedTuple from core.file import FileAttribute, file_manager from core.variables import ArrayFileSegment @@ -10,7 +10,7 @@ from core.workflow.entities import VariablePool from .entities import Condition, SubCondition, SupportedComparisonOperator -def _convert_to_bool(value: Any) -> bool: +def _convert_to_bool(value: object) -> bool: if isinstance(value, int): return bool(value) @@ -23,7 +23,7 @@ def _convert_to_bool(value: Any) -> bool: class ConditionCheckResult(NamedTuple): - inputs: Sequence[Mapping[str, Any]] + inputs: Sequence[Mapping[str, object]] group_results: Sequence[bool] final_result: bool @@ -36,7 +36,7 @@ class ConditionProcessor: conditions: Sequence[Condition], operator: Literal["and", "or"], ) -> ConditionCheckResult: - input_conditions: list[Mapping[str, Any]] = [] + input_conditions: list[Mapping[str, object]] = [] group_results: list[bool] = [] for condition in conditions: @@ -103,8 +103,8 @@ class ConditionProcessor: def _evaluate_condition( *, operator: SupportedComparisonOperator, - value: Any, - expected: Union[str, Sequence[str], bool | Sequence[bool], None], + value: object, + expected: str | Sequence[str] | bool | Sequence[bool] | None, ) -> bool: match operator: case "contains": @@ -144,7 +144,17 @@ def _evaluate_condition( case "not in": return _assert_not_in(value=value, expected=expected) case "all of" if isinstance(expected, list): - return _assert_all_of(value=value, expected=expected) + # Type narrowing: at this point expected is a list, could be list[str] or list[bool] + if all(isinstance(item, str) for item in expected): + # Create a new typed list to satisfy type checker + str_list: list[str] = [item for item in expected if isinstance(item, str)] + return _assert_all_of(value=value, expected=str_list) + elif all(isinstance(item, bool) for item in expected): + # Create a new typed list to satisfy type checker + bool_list: list[bool] = [item for item in expected if isinstance(item, bool)] + return _assert_all_of_bool(value=value, expected=bool_list) + else: + raise ValueError("all of operator expects homogeneous list of strings or booleans") case "exists": return _assert_exists(value=value) case "not exists": @@ -153,55 +163,73 @@ def _evaluate_condition( raise ValueError(f"Unsupported operator: {operator}") -def _assert_contains(*, value: Any, expected: Any) -> bool: +def _assert_contains(*, value: object, expected: object) -> bool: if not value: return False if not isinstance(value, (str, list)): raise ValueError("Invalid actual value type: string or array") - if expected not in value: - return False + # Type checking ensures value is str or list at this point + if isinstance(value, str): + if not isinstance(expected, str): + expected = str(expected) + if expected not in value: + return False + else: # value is list + if expected not in value: + return False return True -def _assert_not_contains(*, value: Any, expected: Any) -> bool: +def _assert_not_contains(*, value: object, expected: object) -> bool: if not value: return True if not isinstance(value, (str, list)): raise ValueError("Invalid actual value type: string or array") - if expected in value: - return False + # Type checking ensures value is str or list at this point + if isinstance(value, str): + if not isinstance(expected, str): + expected = str(expected) + if expected in value: + return False + else: # value is list + if expected in value: + return False return True -def _assert_start_with(*, value: Any, expected: Any) -> bool: +def _assert_start_with(*, value: object, expected: object) -> bool: if not value: return False if not isinstance(value, str): raise ValueError("Invalid actual value type: string") + if not isinstance(expected, str): + raise ValueError("Expected value must be a string for startswith") if not value.startswith(expected): return False return True -def _assert_end_with(*, value: Any, expected: Any) -> bool: +def _assert_end_with(*, value: object, expected: object) -> bool: if not value: return False if not isinstance(value, str): raise ValueError("Invalid actual value type: string") + if not isinstance(expected, str): + raise ValueError("Expected value must be a string for endswith") if not value.endswith(expected): return False return True -def _assert_is(*, value: Any, expected: Any) -> bool: +def _assert_is(*, value: object, expected: object) -> bool: if value is None: return False @@ -213,7 +241,7 @@ def _assert_is(*, value: Any, expected: Any) -> bool: return True -def _assert_is_not(*, value: Any, expected: Any) -> bool: +def _assert_is_not(*, value: object, expected: object) -> bool: if value is None: return False @@ -225,19 +253,19 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool: return True -def _assert_empty(*, value: Any) -> bool: +def _assert_empty(*, value: object) -> bool: if not value: return True return False -def _assert_not_empty(*, value: Any) -> bool: +def _assert_not_empty(*, value: object) -> bool: if value: return True return False -def _assert_equal(*, value: Any, expected: Any) -> bool: +def _assert_equal(*, value: object, expected: object) -> bool: if value is None: return False @@ -246,10 +274,16 @@ def _assert_equal(*, value: Any, expected: Any) -> bool: # Handle boolean comparison if isinstance(value, bool): + if not isinstance(expected, (bool, int, str)): + raise ValueError(f"Cannot convert {type(expected)} to bool") expected = bool(expected) elif isinstance(value, int): + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to int") expected = int(expected) else: + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to float") expected = float(expected) if value != expected: @@ -257,7 +291,7 @@ def _assert_equal(*, value: Any, expected: Any) -> bool: return True -def _assert_not_equal(*, value: Any, expected: Any) -> bool: +def _assert_not_equal(*, value: object, expected: object) -> bool: if value is None: return False @@ -266,10 +300,16 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool: # Handle boolean comparison if isinstance(value, bool): + if not isinstance(expected, (bool, int, str)): + raise ValueError(f"Cannot convert {type(expected)} to bool") expected = bool(expected) elif isinstance(value, int): + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to int") expected = int(expected) else: + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to float") expected = float(expected) if value == expected: @@ -277,7 +317,7 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool: return True -def _assert_greater_than(*, value: Any, expected: Any) -> bool: +def _assert_greater_than(*, value: object, expected: object) -> bool: if value is None: return False @@ -285,8 +325,12 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool: raise ValueError("Invalid actual value type: number") if isinstance(value, int): + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to int") expected = int(expected) else: + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to float") expected = float(expected) if value <= expected: @@ -294,7 +338,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool: return True -def _assert_less_than(*, value: Any, expected: Any) -> bool: +def _assert_less_than(*, value: object, expected: object) -> bool: if value is None: return False @@ -302,8 +346,12 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool: raise ValueError("Invalid actual value type: number") if isinstance(value, int): + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to int") expected = int(expected) else: + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to float") expected = float(expected) if value >= expected: @@ -311,7 +359,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool: return True -def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool: +def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool: if value is None: return False @@ -319,8 +367,12 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool: raise ValueError("Invalid actual value type: number") if isinstance(value, int): + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to int") expected = int(expected) else: + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to float") expected = float(expected) if value < expected: @@ -328,7 +380,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool: return True -def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool: +def _assert_less_than_or_equal(*, value: object, expected: object) -> bool: if value is None: return False @@ -336,8 +388,12 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool: raise ValueError("Invalid actual value type: number") if isinstance(value, int): + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to int") expected = int(expected) else: + if not isinstance(expected, (int, float, str)): + raise ValueError(f"Cannot convert {type(expected)} to float") expected = float(expected) if value > expected: @@ -345,19 +401,19 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool: return True -def _assert_null(*, value: Any) -> bool: +def _assert_null(*, value: object) -> bool: if value is None: return True return False -def _assert_not_null(*, value: Any) -> bool: +def _assert_not_null(*, value: object) -> bool: if value is not None: return True return False -def _assert_in(*, value: Any, expected: Any) -> bool: +def _assert_in(*, value: object, expected: object) -> bool: if not value: return False @@ -369,7 +425,7 @@ def _assert_in(*, value: Any, expected: Any) -> bool: return True -def _assert_not_in(*, value: Any, expected: Any) -> bool: +def _assert_not_in(*, value: object, expected: object) -> bool: if not value: return True @@ -381,20 +437,33 @@ def _assert_not_in(*, value: Any, expected: Any) -> bool: return True -def _assert_all_of(*, value: Any, expected: Sequence[str]) -> bool: +def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool: if not value: return False - if not all(item in value for item in expected): + # Ensure value is a container that supports 'in' operator + if not isinstance(value, (list, tuple, set, str)): return False - return True + + return all(item in value for item in expected) -def _assert_exists(*, value: Any) -> bool: +def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool: + if not value: + return False + + # Ensure value is a container that supports 'in' operator + if not isinstance(value, (list, tuple, set)): + return False + + return all(item in value for item in expected) + + +def _assert_exists(*, value: object) -> bool: return value is not None -def _assert_not_exists(*, value: Any) -> bool: +def _assert_not_exists(*, value: object) -> bool: return value is None @@ -404,7 +473,7 @@ def _process_sub_conditions( operator: Literal["and", "or"], ) -> bool: files = variable.value - group_results = [] + group_results: list[bool] = [] for condition in sub_conditions: key = FileAttribute(condition.key) values = [file_manager.get_attr(file=file, attr=key) for file in files] @@ -415,14 +484,14 @@ def _process_sub_conditions( if expected_value and not expected_value.startswith("."): expected_value = "." + expected_value - normalized_values = [] + normalized_values: list[object] = [] for value in values: if value and isinstance(value, str): if not value.startswith("."): value = "." + value normalized_values.append(value) values = normalized_values - sub_group_results = [ + sub_group_results: list[bool] = [ _evaluate_condition( value=value, operator=condition.comparison_operator,