From 501d3b6203b56f6df1a3dc5992e6f8e588f41615 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Tue, 24 Jun 2025 20:27:22 +0800 Subject: [PATCH] feat(api): Explicitly define version method for all BaseNode subclasses (#21443) This PR addresses issue #21441 by implementing explicit `version` method definitions for all `BaseNode` subclasses to improve code maintainability. ### Changes Added explicit `version` method definitions for all `BaseNode` subclasses: - `QuestionClassifierNode` - `KnowledgeRetrievalNode` - `AgentNode` Added comprehensive test suite to validate: 1. All subclasses of `BaseNode` have explicitly defined `version` method 2. All subclasses have required `_node_type` property 3. The `(node_type, node_version)` combination is unique across all subclasses --- api/core/workflow/nodes/agent/agent_node.py | 4 +++ .../knowledge_retrieval_node.py | 4 +++ .../question_classifier_node.py | 4 +++ .../workflow/nodes/base/test_base_node.py | 36 +++++++++++++++++++ 4 files changed, 48 insertions(+) create mode 100644 api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 22c564c1fc..2f28363955 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -39,6 +39,10 @@ class AgentNode(ToolNode): _node_data_cls = AgentNodeData # type: ignore _node_type = NodeType.AGENT + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator: """ Run the agent node diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2995f0682f..0b9e98f28a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -71,6 +71,10 @@ class KnowledgeRetrievalNode(LLMNode): _node_data_cls = KnowledgeRetrievalNodeData # type: ignore _node_type = NodeType.KNOWLEDGE_RETRIEVAL + @classmethod + def version(cls): + return "1" + def _run(self) -> NodeRunResult: # type: ignore node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 1f50700c7e..a518167cc6 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -40,6 +40,10 @@ class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER + @classmethod + def version(cls): + return "1" + def _run(self): node_data = cast(QuestionClassifierNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py new file mode 100644 index 0000000000..8712b61a23 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -0,0 +1,36 @@ +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import NodeType + +# Ensures that all node classes are imported. +from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + +_ = NODE_TYPE_CLASSES_MAPPING + + +def _get_all_subclasses(root: type[BaseNode]) -> list[type[BaseNode]]: + subclasses = [] + queue = [root] + while queue: + cls = queue.pop() + + subclasses.extend(cls.__subclasses__()) + queue.extend(cls.__subclasses__()) + + return subclasses + + +def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined(): + classes = _get_all_subclasses(BaseNode) # type: ignore + type_version_set: set[tuple[NodeType, str]] = set() + + for cls in classes: + # Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__ + assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)" + node_type = cls._node_type + node_version = cls.version() + + assert isinstance(cls._node_type, NodeType) + assert isinstance(node_version, str) + node_type_and_version = (node_type, node_version) + assert node_type_and_version not in type_version_set + type_version_set.add(node_type_and_version)