From dd3b1ccd45e34bf2a5115219bde6feeeb2a1919b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 28 Nov 2025 15:38:46 +0800 Subject: [PATCH] refactor(workflow): remove redundant get_base_node_data() method (#28803) --- api/core/workflow/nodes/base/node.py | 36 +++++++++---------- .../graph_engine/test_graph_engine.py | 2 +- .../workflow/nodes/code/code_node_spec.py | 6 ++-- .../nodes/iteration/iteration_node_spec.py | 6 ++-- 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index bbdd3099da..592bea0e16 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -240,23 +240,23 @@ class Node(Generic[NodeDataT]): from core.workflow.nodes.tool.tool_node import ToolNode if isinstance(self, ToolNode): - start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "") - start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "") + start_event.provider_id = getattr(self.node_data, "provider_id", "") + start_event.provider_type = getattr(self.node_data, "provider_type", "") from core.workflow.nodes.datasource.datasource_node import DatasourceNode if isinstance(self, DatasourceNode): - plugin_id = getattr(self.get_base_node_data(), "plugin_id", "") - provider_name = getattr(self.get_base_node_data(), "provider_name", "") + plugin_id = getattr(self.node_data, "plugin_id", "") + provider_name = getattr(self.node_data, "provider_name", "") start_event.provider_id = f"{plugin_id}/{provider_name}" - start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "") + start_event.provider_type = getattr(self.node_data, "provider_type", "") from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode if isinstance(self, TriggerEventNode): - start_event.provider_id = getattr(self.get_base_node_data(), "provider_id", "") - start_event.provider_type = getattr(self.get_base_node_data(), "provider_type", "") + start_event.provider_id = getattr(self.node_data, "provider_id", "") + start_event.provider_type = getattr(self.node_data, "provider_type", "") from typing import cast @@ -265,7 +265,7 @@ class Node(Generic[NodeDataT]): if isinstance(self, AgentNode): start_event.agent_strategy = AgentNodeStrategyInit( - name=cast(AgentNodeData, self.get_base_node_data()).agent_strategy_name, + name=cast(AgentNodeData, self.node_data).agent_strategy_name, icon=self.agent_strategy_icon, ) @@ -419,10 +419,6 @@ class Node(Generic[NodeDataT]): """Get the default values dictionary for this node.""" return self._node_data.default_value_dict - def get_base_node_data(self) -> BaseNodeData: - """Get the BaseNodeData object for this node.""" - return self._node_data - # Public interface properties that delegate to abstract methods @property def error_strategy(self) -> ErrorStrategy | None: @@ -548,7 +544,7 @@ class Node(Generic[NodeDataT]): id=self._node_execution_id, node_id=self._node_id, node_type=self.node_type, - node_title=self.get_base_node_data().title, + node_title=self.node_data.title, start_at=event.start_at, inputs=event.inputs, metadata=event.metadata, @@ -561,7 +557,7 @@ class Node(Generic[NodeDataT]): id=self._node_execution_id, node_id=self._node_id, node_type=self.node_type, - node_title=self.get_base_node_data().title, + node_title=self.node_data.title, index=event.index, pre_loop_output=event.pre_loop_output, ) @@ -572,7 +568,7 @@ class Node(Generic[NodeDataT]): id=self._node_execution_id, node_id=self._node_id, node_type=self.node_type, - node_title=self.get_base_node_data().title, + node_title=self.node_data.title, start_at=event.start_at, inputs=event.inputs, outputs=event.outputs, @@ -586,7 +582,7 @@ class Node(Generic[NodeDataT]): id=self._node_execution_id, node_id=self._node_id, node_type=self.node_type, - node_title=self.get_base_node_data().title, + node_title=self.node_data.title, start_at=event.start_at, inputs=event.inputs, outputs=event.outputs, @@ -601,7 +597,7 @@ class Node(Generic[NodeDataT]): id=self._node_execution_id, node_id=self._node_id, node_type=self.node_type, - node_title=self.get_base_node_data().title, + node_title=self.node_data.title, start_at=event.start_at, inputs=event.inputs, metadata=event.metadata, @@ -614,7 +610,7 @@ class Node(Generic[NodeDataT]): id=self._node_execution_id, node_id=self._node_id, node_type=self.node_type, - node_title=self.get_base_node_data().title, + node_title=self.node_data.title, index=event.index, pre_iteration_output=event.pre_iteration_output, ) @@ -625,7 +621,7 @@ class Node(Generic[NodeDataT]): id=self._node_execution_id, node_id=self._node_id, node_type=self.node_type, - node_title=self.get_base_node_data().title, + node_title=self.node_data.title, start_at=event.start_at, inputs=event.inputs, outputs=event.outputs, @@ -639,7 +635,7 @@ class Node(Generic[NodeDataT]): id=self._node_execution_id, node_id=self._node_id, node_type=self.node_type, - node_title=self.get_base_node_data().title, + node_title=self.node_data.title, start_at=event.start_at, inputs=event.inputs, outputs=event.outputs, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 4a117f8c96..02f20413e0 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -744,7 +744,7 @@ def test_graph_run_emits_partial_success_when_node_failure_recovered(): ) llm_node = graph.nodes["llm"] - base_node_data = llm_node.get_base_node_data() + base_node_data = llm_node.node_data base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)] diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index f62c714820..596e72ddd0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -471,8 +471,8 @@ class TestCodeNodeInitialization: assert node._get_description() is None - def test_get_base_node_data(self): - """Test get_base_node_data returns node data.""" + def test_node_data_property(self): + """Test node_data property returns node data.""" node = CodeNode.__new__(CodeNode) node._node_data = CodeNodeData( title="Base Test", @@ -482,7 +482,7 @@ class TestCodeNodeInitialization: outputs={}, ) - result = node.get_base_node_data() + result = node.node_data assert result == node._node_data assert result.title == "Base Test" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py index 51af4367f7..b67e84d1d4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py @@ -240,8 +240,8 @@ class TestIterationNodeInitialization: assert node._get_description() == "This is a description" - def test_get_base_node_data(self): - """Test get_base_node_data returns node data.""" + def test_node_data_property(self): + """Test node_data property returns node data.""" node = IterationNode.__new__(IterationNode) node._node_data = IterationNodeData( title="Base Test", @@ -249,7 +249,7 @@ class TestIterationNodeInitialization: output_selector=["y"], ) - result = node.get_base_node_data() + result = node.node_data assert result == node._node_data