diff --git a/api/core/workflow/entities/graph_runtime_state.py b/api/core/workflow/entities/graph_runtime_state.py index 19aa0d27e6..e9bf2ea26c 100644 --- a/api/core/workflow/entities/graph_runtime_state.py +++ b/api/core/workflow/entities/graph_runtime_state.py @@ -1,6 +1,7 @@ +from copy import deepcopy from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, PrivateAttr from core.model_runtime.entities.llm_entities import LLMUsage @@ -8,21 +9,132 @@ from .variable_pool import VariablePool class GraphRuntimeState(BaseModel): - variable_pool: VariablePool = Field(..., description="variable pool") - """variable pool""" + # Private attributes to prevent direct modification + _variable_pool: VariablePool = PrivateAttr() + _start_at: float = PrivateAttr() + _total_tokens: int = PrivateAttr(default=0) + _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage) + _outputs: dict[str, Any] = PrivateAttr(default_factory=dict) + _node_run_steps: int = PrivateAttr(default=0) - start_at: float = Field(..., description="start time") - """start time""" - total_tokens: int = 0 - """total tokens""" - llm_usage: LLMUsage = LLMUsage.empty_usage() - """llm usage info""" + def __init__( + self, + variable_pool: VariablePool, + start_at: float, + total_tokens: int = 0, + llm_usage: LLMUsage | None = None, + outputs: dict[str, Any] | None = None, + node_run_steps: int = 0, + **kwargs, + ): + """Initialize the GraphRuntimeState with validation.""" + super().__init__(**kwargs) - # The `outputs` field stores the final output values generated by executing workflows or chatflows. - # - # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent - # after a serialization and deserialization round trip. - outputs: dict[str, Any] = Field(default_factory=dict) + # Initialize private attributes with validation + self._variable_pool = variable_pool - node_run_steps: int = 0 - """node run steps""" + self._start_at = start_at + + if total_tokens < 0: + raise ValueError("total_tokens must be non-negative") + self._total_tokens = total_tokens + + if llm_usage is None: + llm_usage = LLMUsage.empty_usage() + self._llm_usage = llm_usage + + if outputs is None: + outputs = {} + self._outputs = deepcopy(outputs) + + if node_run_steps < 0: + raise ValueError("node_run_steps must be non-negative") + self._node_run_steps = node_run_steps + + @property + def variable_pool(self) -> VariablePool: + """Get the variable pool.""" + return self._variable_pool + + @variable_pool.setter + def variable_pool(self, value: VariablePool) -> None: + """Set the variable pool.""" + self._variable_pool = value + + @property + def start_at(self) -> float: + """Get the start time.""" + return self._start_at + + @start_at.setter + def start_at(self, value: float) -> None: + """Set the start time.""" + self._start_at = value + + @property + def total_tokens(self) -> int: + """Get the total tokens count.""" + return self._total_tokens + + @total_tokens.setter + def total_tokens(self, value: int): + """Set the total tokens count.""" + if value < 0: + raise ValueError("total_tokens must be non-negative") + self._total_tokens = value + + @property + def llm_usage(self) -> LLMUsage: + """Get the LLM usage info.""" + # Return a copy to prevent external modification + return self._llm_usage.model_copy() + + @llm_usage.setter + def llm_usage(self, value: LLMUsage): + """Set the LLM usage info.""" + self._llm_usage = value.model_copy() + + @property + def outputs(self) -> dict[str, Any]: + """Get a copy of the outputs dictionary.""" + return deepcopy(self._outputs) + + @outputs.setter + def outputs(self, value: dict[str, Any]) -> None: + """Set the outputs dictionary.""" + self._outputs = deepcopy(value) + + def set_output(self, key: str, value: Any) -> None: + """Set a single output value.""" + self._outputs[key] = deepcopy(value) + + def get_output(self, key: str, default: Any = None) -> Any: + """Get a single output value.""" + return deepcopy(self._outputs.get(key, default)) + + def update_outputs(self, updates: dict[str, Any]) -> None: + """Update multiple output values.""" + for key, value in updates.items(): + self._outputs[key] = deepcopy(value) + + @property + def node_run_steps(self) -> int: + """Get the node run steps count.""" + return self._node_run_steps + + @node_run_steps.setter + def node_run_steps(self, value: int) -> None: + """Set the node run steps count.""" + if value < 0: + raise ValueError("node_run_steps must be non-negative") + self._node_run_steps = value + + def increment_node_run_steps(self) -> None: + """Increment the node run steps by 1.""" + self._node_run_steps += 1 + + def add_tokens(self, tokens: int) -> None: + """Add tokens to the total count.""" + if tokens < 0: + raise ValueError("tokens must be non-negative") + self._total_tokens += tokens diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/core/workflow/graph_engine/event_management/event_handlers.py index 4efb5991ec..3ab69776a4 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/core/workflow/graph_engine/event_management/event_handlers.py @@ -267,10 +267,10 @@ class EventHandler: # in runtime state, rather than allowing nodes to directly access runtime state. for key, value in event.node_run_result.outputs.items(): if key == "answer": - existing = self._graph_runtime_state.outputs.get("answer", "") + existing = self._graph_runtime_state.get_output("answer", "") if existing: - self._graph_runtime_state.outputs["answer"] = f"{existing}{value}" + self._graph_runtime_state.set_output("answer", f"{existing}{value}") else: - self._graph_runtime_state.outputs["answer"] = value + self._graph_runtime_state.set_output("answer", value) else: - self._graph_runtime_state.outputs[key] = value + self._graph_runtime_state.set_output(key, value) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index dffeee66f5..11d46fe998 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -147,14 +147,14 @@ class LoopNode(Node): for key, value in graph_engine.graph_runtime_state.outputs.items(): if key == "answer": # Concatenate answer outputs with newline - existing_answer = self.graph_runtime_state.outputs.get("answer", "") + existing_answer = self.graph_runtime_state.get_output("answer", "") if existing_answer: - self.graph_runtime_state.outputs["answer"] = f"{existing_answer}{value}" + self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}") else: - self.graph_runtime_state.outputs["answer"] = value + self.graph_runtime_state.set_output("answer", value) else: # For other outputs, just update - self.graph_runtime_state.outputs[key] = value + self.graph_runtime_state.set_output(key, value) # Update the total tokens from this iteration cost_tokens += graph_engine.graph_runtime_state.total_tokens diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py new file mode 100644 index 0000000000..61c1fd3181 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -0,0 +1,114 @@ +from time import time + +import pytest + +from core.workflow.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.entities.variable_pool import VariablePool + + +class TestGraphRuntimeState: + def test_property_getters_and_setters(self): + # FIXME(-LAN-): Mock VariablePool if needed + variable_pool = VariablePool() + start_time = time() + + state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time) + + # Test variable_pool property + assert state.variable_pool == variable_pool + new_pool = VariablePool() + state.variable_pool = new_pool + assert state.variable_pool == new_pool + + # Test start_at property + assert state.start_at == start_time + new_time = time() + 100 + state.start_at = new_time + assert state.start_at == new_time + + # Test total_tokens property + assert state.total_tokens == 0 + state.total_tokens = 100 + assert state.total_tokens == 100 + + # Test node_run_steps property + assert state.node_run_steps == 0 + state.node_run_steps = 5 + assert state.node_run_steps == 5 + + def test_outputs_immutability(self): + variable_pool = VariablePool() + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + + # Test that getting outputs returns a copy + outputs1 = state.outputs + outputs2 = state.outputs + assert outputs1 == outputs2 + assert outputs1 is not outputs2 # Different objects + + # Test that modifying retrieved outputs doesn't affect internal state + outputs = state.outputs + outputs["test"] = "value" + assert "test" not in state.outputs + + # Test set_output method + state.set_output("key1", "value1") + assert state.get_output("key1") == "value1" + + # Test update_outputs method + state.update_outputs({"key2": "value2", "key3": "value3"}) + assert state.get_output("key2") == "value2" + assert state.get_output("key3") == "value3" + + def test_llm_usage_immutability(self): + variable_pool = VariablePool() + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + + # Test that getting llm_usage returns a copy + usage1 = state.llm_usage + usage2 = state.llm_usage + assert usage1 is not usage2 # Different objects + + def test_type_validation(self): + variable_pool = VariablePool() + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + + # Test total_tokens validation + with pytest.raises(ValueError): + state.total_tokens = -1 + + # Test node_run_steps validation + with pytest.raises(ValueError): + state.node_run_steps = -1 + + def test_helper_methods(self): + variable_pool = VariablePool() + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + + # Test increment_node_run_steps + initial_steps = state.node_run_steps + state.increment_node_run_steps() + assert state.node_run_steps == initial_steps + 1 + + # Test add_tokens + initial_tokens = state.total_tokens + state.add_tokens(50) + assert state.total_tokens == initial_tokens + 50 + + # Test add_tokens validation + with pytest.raises(ValueError): + state.add_tokens(-1) + + def test_deep_copy_for_nested_objects(self): + variable_pool = VariablePool() + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + + # Test deep copy for nested dict + nested_data = {"level1": {"level2": {"value": "test"}}} + state.set_output("nested", nested_data) + + retrieved = state.get_output("nested") + retrieved["level1"]["level2"]["value"] = "modified" + + # Original should remain unchanged + assert state.get_output("nested")["level1"]["level2"]["value"] == "test"