mirror of https://github.com/langgenius/dify.git
feat: add property-based access control to GraphRuntimeState
- Replace direct field access with private attributes and property decorators - Implement deep copy protection for mutable objects (dict, LLMUsage) - Add helper methods: set_output(), get_output(), update_outputs() - Add increment_node_run_steps() and add_tokens() convenience methods - Update loop_node and event_handlers to use new accessor methods - Add comprehensive unit tests for immutability and validation - Ensure backward compatibility with existing property access patterns
This commit is contained in:
parent
9c96b23d55
commit
fe3f03e50a
|
|
@ -1,6 +1,7 @@
|
|||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
|
@ -8,21 +9,132 @@ from .variable_pool import VariablePool
|
|||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
"""variable pool"""
|
||||
# Private attributes to prevent direct modification
|
||||
_variable_pool: VariablePool = PrivateAttr()
|
||||
_start_at: float = PrivateAttr()
|
||||
_total_tokens: int = PrivateAttr(default=0)
|
||||
_llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
|
||||
_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
_node_run_steps: int = PrivateAttr(default=0)
|
||||
|
||||
start_at: float = Field(..., description="start time")
|
||||
"""start time"""
|
||||
total_tokens: int = 0
|
||||
"""total tokens"""
|
||||
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
||||
"""llm usage info"""
|
||||
def __init__(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
start_at: float,
|
||||
total_tokens: int = 0,
|
||||
llm_usage: LLMUsage | None = None,
|
||||
outputs: dict[str, Any] | None = None,
|
||||
node_run_steps: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the GraphRuntimeState with validation."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
|
||||
#
|
||||
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
|
||||
# after a serialization and deserialization round trip.
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
# Initialize private attributes with validation
|
||||
self._variable_pool = variable_pool
|
||||
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
self._start_at = start_at
|
||||
|
||||
if total_tokens < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = total_tokens
|
||||
|
||||
if llm_usage is None:
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
self._llm_usage = llm_usage
|
||||
|
||||
if outputs is None:
|
||||
outputs = {}
|
||||
self._outputs = deepcopy(outputs)
|
||||
|
||||
if node_run_steps < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = node_run_steps
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> VariablePool:
|
||||
"""Get the variable pool."""
|
||||
return self._variable_pool
|
||||
|
||||
@variable_pool.setter
|
||||
def variable_pool(self, value: VariablePool) -> None:
|
||||
"""Set the variable pool."""
|
||||
self._variable_pool = value
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
"""Get the start time."""
|
||||
return self._start_at
|
||||
|
||||
@start_at.setter
|
||||
def start_at(self, value: float) -> None:
|
||||
"""Set the start time."""
|
||||
self._start_at = value
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
"""Get the total tokens count."""
|
||||
return self._total_tokens
|
||||
|
||||
@total_tokens.setter
|
||||
def total_tokens(self, value: int):
|
||||
"""Set the total tokens count."""
|
||||
if value < 0:
|
||||
raise ValueError("total_tokens must be non-negative")
|
||||
self._total_tokens = value
|
||||
|
||||
@property
|
||||
def llm_usage(self) -> LLMUsage:
|
||||
"""Get the LLM usage info."""
|
||||
# Return a copy to prevent external modification
|
||||
return self._llm_usage.model_copy()
|
||||
|
||||
@llm_usage.setter
|
||||
def llm_usage(self, value: LLMUsage):
|
||||
"""Set the LLM usage info."""
|
||||
self._llm_usage = value.model_copy()
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, Any]:
|
||||
"""Get a copy of the outputs dictionary."""
|
||||
return deepcopy(self._outputs)
|
||||
|
||||
@outputs.setter
|
||||
def outputs(self, value: dict[str, Any]) -> None:
|
||||
"""Set the outputs dictionary."""
|
||||
self._outputs = deepcopy(value)
|
||||
|
||||
def set_output(self, key: str, value: Any) -> None:
|
||||
"""Set a single output value."""
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
def get_output(self, key: str, default: Any = None) -> Any:
|
||||
"""Get a single output value."""
|
||||
return deepcopy(self._outputs.get(key, default))
|
||||
|
||||
def update_outputs(self, updates: dict[str, Any]) -> None:
|
||||
"""Update multiple output values."""
|
||||
for key, value in updates.items():
|
||||
self._outputs[key] = deepcopy(value)
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
"""Get the node run steps count."""
|
||||
return self._node_run_steps
|
||||
|
||||
@node_run_steps.setter
|
||||
def node_run_steps(self, value: int) -> None:
|
||||
"""Set the node run steps count."""
|
||||
if value < 0:
|
||||
raise ValueError("node_run_steps must be non-negative")
|
||||
self._node_run_steps = value
|
||||
|
||||
def increment_node_run_steps(self) -> None:
|
||||
"""Increment the node run steps by 1."""
|
||||
self._node_run_steps += 1
|
||||
|
||||
def add_tokens(self, tokens: int) -> None:
|
||||
"""Add tokens to the total count."""
|
||||
if tokens < 0:
|
||||
raise ValueError("tokens must be non-negative")
|
||||
self._total_tokens += tokens
|
||||
|
|
|
|||
|
|
@ -267,10 +267,10 @@ class EventHandler:
|
|||
# in runtime state, rather than allowing nodes to directly access runtime state.
|
||||
for key, value in event.node_run_result.outputs.items():
|
||||
if key == "answer":
|
||||
existing = self._graph_runtime_state.outputs.get("answer", "")
|
||||
existing = self._graph_runtime_state.get_output("answer", "")
|
||||
if existing:
|
||||
self._graph_runtime_state.outputs["answer"] = f"{existing}{value}"
|
||||
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
|
||||
else:
|
||||
self._graph_runtime_state.outputs["answer"] = value
|
||||
self._graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
self._graph_runtime_state.outputs[key] = value
|
||||
self._graph_runtime_state.set_output(key, value)
|
||||
|
|
|
|||
|
|
@ -147,14 +147,14 @@ class LoopNode(Node):
|
|||
for key, value in graph_engine.graph_runtime_state.outputs.items():
|
||||
if key == "answer":
|
||||
# Concatenate answer outputs with newline
|
||||
existing_answer = self.graph_runtime_state.outputs.get("answer", "")
|
||||
existing_answer = self.graph_runtime_state.get_output("answer", "")
|
||||
if existing_answer:
|
||||
self.graph_runtime_state.outputs["answer"] = f"{existing_answer}{value}"
|
||||
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}")
|
||||
else:
|
||||
self.graph_runtime_state.outputs["answer"] = value
|
||||
self.graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
# For other outputs, just update
|
||||
self.graph_runtime_state.outputs[key] = value
|
||||
self.graph_runtime_state.set_output(key, value)
|
||||
|
||||
# Update the total tokens from this iteration
|
||||
cost_tokens += graph_engine.graph_runtime_state.total_tokens
|
||||
|
|
|
|||
|
|
@ -0,0 +1,114 @@
|
|||
from time import time
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
class TestGraphRuntimeState:
|
||||
def test_property_getters_and_setters(self):
|
||||
# FIXME(-LAN-): Mock VariablePool if needed
|
||||
variable_pool = VariablePool()
|
||||
start_time = time()
|
||||
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time)
|
||||
|
||||
# Test variable_pool property
|
||||
assert state.variable_pool == variable_pool
|
||||
new_pool = VariablePool()
|
||||
state.variable_pool = new_pool
|
||||
assert state.variable_pool == new_pool
|
||||
|
||||
# Test start_at property
|
||||
assert state.start_at == start_time
|
||||
new_time = time() + 100
|
||||
state.start_at = new_time
|
||||
assert state.start_at == new_time
|
||||
|
||||
# Test total_tokens property
|
||||
assert state.total_tokens == 0
|
||||
state.total_tokens = 100
|
||||
assert state.total_tokens == 100
|
||||
|
||||
# Test node_run_steps property
|
||||
assert state.node_run_steps == 0
|
||||
state.node_run_steps = 5
|
||||
assert state.node_run_steps == 5
|
||||
|
||||
def test_outputs_immutability(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test that getting outputs returns a copy
|
||||
outputs1 = state.outputs
|
||||
outputs2 = state.outputs
|
||||
assert outputs1 == outputs2
|
||||
assert outputs1 is not outputs2 # Different objects
|
||||
|
||||
# Test that modifying retrieved outputs doesn't affect internal state
|
||||
outputs = state.outputs
|
||||
outputs["test"] = "value"
|
||||
assert "test" not in state.outputs
|
||||
|
||||
# Test set_output method
|
||||
state.set_output("key1", "value1")
|
||||
assert state.get_output("key1") == "value1"
|
||||
|
||||
# Test update_outputs method
|
||||
state.update_outputs({"key2": "value2", "key3": "value3"})
|
||||
assert state.get_output("key2") == "value2"
|
||||
assert state.get_output("key3") == "value3"
|
||||
|
||||
def test_llm_usage_immutability(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test that getting llm_usage returns a copy
|
||||
usage1 = state.llm_usage
|
||||
usage2 = state.llm_usage
|
||||
assert usage1 is not usage2 # Different objects
|
||||
|
||||
def test_type_validation(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test total_tokens validation
|
||||
with pytest.raises(ValueError):
|
||||
state.total_tokens = -1
|
||||
|
||||
# Test node_run_steps validation
|
||||
with pytest.raises(ValueError):
|
||||
state.node_run_steps = -1
|
||||
|
||||
def test_helper_methods(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test increment_node_run_steps
|
||||
initial_steps = state.node_run_steps
|
||||
state.increment_node_run_steps()
|
||||
assert state.node_run_steps == initial_steps + 1
|
||||
|
||||
# Test add_tokens
|
||||
initial_tokens = state.total_tokens
|
||||
state.add_tokens(50)
|
||||
assert state.total_tokens == initial_tokens + 50
|
||||
|
||||
# Test add_tokens validation
|
||||
with pytest.raises(ValueError):
|
||||
state.add_tokens(-1)
|
||||
|
||||
def test_deep_copy_for_nested_objects(self):
|
||||
variable_pool = VariablePool()
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
|
||||
# Test deep copy for nested dict
|
||||
nested_data = {"level1": {"level2": {"value": "test"}}}
|
||||
state.set_output("nested", nested_data)
|
||||
|
||||
retrieved = state.get_output("nested")
|
||||
retrieved["level1"]["level2"]["value"] = "modified"
|
||||
|
||||
# Original should remain unchanged
|
||||
assert state.get_output("nested")["level1"]["level2"]["value"] == "test"
|
||||
Loading…
Reference in New Issue