From d162f7e5ef0db74d3396239c82e6283732f043ae Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 1 Dec 2025 14:14:19 +0800 Subject: [PATCH] feat(api): automatically `NODE_TYPE_CLASSES_MAPPING` generation from node class definitions (#28525) --- api/app_factory.py | 2 + api/core/app/entities/app_invoke_entities.py | 33 ++-- api/core/workflow/nodes/base/node.py | 58 +++++++ api/core/workflow/nodes/node_mapping.py | 160 +----------------- api/core/workflow/nodes/tool/tool_node.py | 16 +- api/extensions/ext_forward_refs.py | 49 ++++++ .../workflow/graph/test_graph_validation.py | 2 +- .../workflow/graph_engine/test_mock_nodes.py | 24 +-- .../workflow/nodes/base/test_base_node.py | 4 + .../test_get_node_type_classes_mapping.py | 84 +++++++++ .../core/workflow/nodes/test_base_node.py | 2 +- 11 files changed, 245 insertions(+), 189 deletions(-) create mode 100644 api/extensions/ext_forward_refs.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py diff --git a/api/app_factory.py b/api/app_factory.py index 933cf294d1..ad2065682c 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_forward_refs, ext_hosting_provider, ext_import_modules, ext_logging, @@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp): ext_warnings, ext_import_modules, ext_orjson, + ext_forward_refs, ext_set_secretkey, ext_compress, ext_code_based_extension, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 5143dbf1e8..81c355eb10 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -130,7 +130,7 @@ class AppGenerateEntity(BaseModel): # extra parameters, like: auto_generate_conversation_name extras: dict[str, Any] = Field(default_factory=dict) - # tracing instance + # tracing instance; use forward ref to avoid circular import at import time trace_manager: Optional["TraceQueueManager"] = None @@ -275,16 +275,23 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): start_node_id: str | None = None -# Import TraceQueueManager at runtime to resolve forward references -from core.ops.ops_trace_manager import TraceQueueManager +# NOTE: Avoid importing heavy tracing modules at import time to prevent circular imports. +# Forward reference to TraceQueueManager is kept as a string; we rebuild with a stub now to +# avoid Pydantic forward-ref errors in test contexts, and with the real class at app startup. -# Rebuild models that use forward references -AppGenerateEntity.model_rebuild() -EasyUIBasedAppGenerateEntity.model_rebuild() -ConversationAppGenerateEntity.model_rebuild() -ChatAppGenerateEntity.model_rebuild() -CompletionAppGenerateEntity.model_rebuild() -AgentChatAppGenerateEntity.model_rebuild() -AdvancedChatAppGenerateEntity.model_rebuild() -WorkflowAppGenerateEntity.model_rebuild() -RagPipelineGenerateEntity.model_rebuild() + +# Minimal stub to satisfy Pydantic model_rebuild in environments where the real type is not importable yet. +class _TraceQueueManagerStub: + pass + + +_ns = {"TraceQueueManager": _TraceQueueManagerStub} +AppGenerateEntity.model_rebuild(_types_namespace=_ns) +EasyUIBasedAppGenerateEntity.model_rebuild(_types_namespace=_ns) +ConversationAppGenerateEntity.model_rebuild(_types_namespace=_ns) +ChatAppGenerateEntity.model_rebuild(_types_namespace=_ns) +CompletionAppGenerateEntity.model_rebuild(_types_namespace=_ns) +AgentChatAppGenerateEntity.model_rebuild(_types_namespace=_ns) +AdvancedChatAppGenerateEntity.model_rebuild(_types_namespace=_ns) +WorkflowAppGenerateEntity.model_rebuild(_types_namespace=_ns) +RagPipelineGenerateEntity.model_rebuild(_types_namespace=_ns) diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 592bea0e16..c2e1105971 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,7 +1,11 @@ +import importlib import logging +import operator +import pkgutil from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod +from types import MappingProxyType from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from uuid import uuid4 @@ -134,6 +138,34 @@ class Node(Generic[NodeDataT]): cls._node_data_type = node_data_type + # Skip base class itself + if cls is Node: + return + # Only register production node implementations defined under core.workflow.nodes.* + # This prevents test helper subclasses from polluting the global registry and + # accidentally overriding real node types (e.g., a test Answer node). + module_name = getattr(cls, "__module__", "") + # Only register concrete subclasses that define node_type and version() + node_type = cls.node_type + version = cls.version() + bucket = Node._registry.setdefault(node_type, {}) + if module_name.startswith("core.workflow.nodes."): + # Production node definitions take precedence and may override + bucket[version] = cls # type: ignore[index] + else: + # External/test subclasses may register but must not override production + bucket.setdefault(version, cls) # type: ignore[index] + # Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic + version_keys = [v for v in bucket if v != "latest"] + numeric_pairs: list[tuple[str, int]] = [] + for v in version_keys: + numeric_pairs.append((v, int(v))) + if numeric_pairs: + latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0] + else: + latest_key = max(version_keys) if version_keys else version + bucket["latest"] = bucket[latest_key] + @classmethod def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: """ @@ -165,6 +197,9 @@ class Node(Generic[NodeDataT]): return None + # Global registry populated via __init_subclass__ + _registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {} + def __init__( self, id: str, @@ -395,6 +430,29 @@ class Node(Generic[NodeDataT]): # in `api/core/workflow/nodes/__init__.py`. raise NotImplementedError("subclasses of BaseNode must implement `version` method.") + @classmethod + def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]: + """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. + + Import all modules under core.workflow.nodes so subclasses register themselves on import. + Then we return a readonly view of the registry to avoid accidental mutation. + """ + # Import all node modules to ensure they are loaded (thus registered) + import core.workflow.nodes as _nodes_pkg + + for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."): + # Avoid importing modules that depend on the registry to prevent circular imports + # e.g. node_factory imports node_mapping which builds the mapping here. + if _modname in { + "core.workflow.nodes.node_factory", + "core.workflow.nodes.node_mapping", + }: + continue + importlib.import_module(_modname) + + # Return a readonly view so callers can't mutate the registry by accident + return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()} + @property def retry(self) -> bool: return False diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index b926645f18..85df543a2a 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -1,165 +1,9 @@ from collections.abc import Mapping from core.workflow.enums import NodeType -from core.workflow.nodes.agent.agent_node import AgentNode -from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base.node import Node -from core.workflow.nodes.code import CodeNode -from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from core.workflow.nodes.document_extractor import DocumentExtractorNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request import HttpRequestNode -from core.workflow.nodes.human_input import HumanInputNode -from core.workflow.nodes.if_else import IfElseNode -from core.workflow.nodes.iteration import IterationNode, IterationStartNode -from core.workflow.nodes.knowledge_index import KnowledgeIndexNode -from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode -from core.workflow.nodes.list_operator import ListOperatorNode -from core.workflow.nodes.llm import LLMNode -from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode -from core.workflow.nodes.parameter_extractor import ParameterExtractorNode -from core.workflow.nodes.question_classifier import QuestionClassifierNode -from core.workflow.nodes.start import StartNode -from core.workflow.nodes.template_transform import TemplateTransformNode -from core.workflow.nodes.tool import ToolNode -from core.workflow.nodes.trigger_plugin import TriggerEventNode -from core.workflow.nodes.trigger_schedule import TriggerScheduleNode -from core.workflow.nodes.trigger_webhook import TriggerWebhookNode -from core.workflow.nodes.variable_aggregator import VariableAggregatorNode -from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1 -from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2 LATEST_VERSION = "latest" -# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode. -# Specifically, if you have introduced new node types, you should add them here. -# -# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__` -# hook. Try to avoid duplication of node information. -NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = { - NodeType.START: { - LATEST_VERSION: StartNode, - "1": StartNode, - }, - NodeType.END: { - LATEST_VERSION: EndNode, - "1": EndNode, - }, - NodeType.ANSWER: { - LATEST_VERSION: AnswerNode, - "1": AnswerNode, - }, - NodeType.LLM: { - LATEST_VERSION: LLMNode, - "1": LLMNode, - }, - NodeType.KNOWLEDGE_RETRIEVAL: { - LATEST_VERSION: KnowledgeRetrievalNode, - "1": KnowledgeRetrievalNode, - }, - NodeType.IF_ELSE: { - LATEST_VERSION: IfElseNode, - "1": IfElseNode, - }, - NodeType.CODE: { - LATEST_VERSION: CodeNode, - "1": CodeNode, - }, - NodeType.TEMPLATE_TRANSFORM: { - LATEST_VERSION: TemplateTransformNode, - "1": TemplateTransformNode, - }, - NodeType.QUESTION_CLASSIFIER: { - LATEST_VERSION: QuestionClassifierNode, - "1": QuestionClassifierNode, - }, - NodeType.HTTP_REQUEST: { - LATEST_VERSION: HttpRequestNode, - "1": HttpRequestNode, - }, - NodeType.TOOL: { - LATEST_VERSION: ToolNode, - # This is an issue that caused problems before. - # Logically, we shouldn't use two different versions to point to the same class here, - # but in order to maintain compatibility with historical data, this approach has been retained. - "2": ToolNode, - "1": ToolNode, - }, - NodeType.VARIABLE_AGGREGATOR: { - LATEST_VERSION: VariableAggregatorNode, - "1": VariableAggregatorNode, - }, - NodeType.LEGACY_VARIABLE_AGGREGATOR: { - LATEST_VERSION: VariableAggregatorNode, - "1": VariableAggregatorNode, - }, # original name of VARIABLE_AGGREGATOR - NodeType.ITERATION: { - LATEST_VERSION: IterationNode, - "1": IterationNode, - }, - NodeType.ITERATION_START: { - LATEST_VERSION: IterationStartNode, - "1": IterationStartNode, - }, - NodeType.LOOP: { - LATEST_VERSION: LoopNode, - "1": LoopNode, - }, - NodeType.LOOP_START: { - LATEST_VERSION: LoopStartNode, - "1": LoopStartNode, - }, - NodeType.LOOP_END: { - LATEST_VERSION: LoopEndNode, - "1": LoopEndNode, - }, - NodeType.PARAMETER_EXTRACTOR: { - LATEST_VERSION: ParameterExtractorNode, - "1": ParameterExtractorNode, - }, - NodeType.VARIABLE_ASSIGNER: { - LATEST_VERSION: VariableAssignerNodeV2, - "1": VariableAssignerNodeV1, - "2": VariableAssignerNodeV2, - }, - NodeType.DOCUMENT_EXTRACTOR: { - LATEST_VERSION: DocumentExtractorNode, - "1": DocumentExtractorNode, - }, - NodeType.LIST_OPERATOR: { - LATEST_VERSION: ListOperatorNode, - "1": ListOperatorNode, - }, - NodeType.AGENT: { - LATEST_VERSION: AgentNode, - # This is an issue that caused problems before. - # Logically, we shouldn't use two different versions to point to the same class here, - # but in order to maintain compatibility with historical data, this approach has been retained. - "2": AgentNode, - "1": AgentNode, - }, - NodeType.HUMAN_INPUT: { - LATEST_VERSION: HumanInputNode, - "1": HumanInputNode, - }, - NodeType.DATASOURCE: { - LATEST_VERSION: DatasourceNode, - "1": DatasourceNode, - }, - NodeType.KNOWLEDGE_INDEX: { - LATEST_VERSION: KnowledgeIndexNode, - "1": KnowledgeIndexNode, - }, - NodeType.TRIGGER_WEBHOOK: { - LATEST_VERSION: TriggerWebhookNode, - "1": TriggerWebhookNode, - }, - NodeType.TRIGGER_PLUGIN: { - LATEST_VERSION: TriggerEventNode, - "1": TriggerEventNode, - }, - NodeType.TRIGGER_SCHEDULE: { - LATEST_VERSION: TriggerScheduleNode, - "1": TriggerScheduleNode, - }, -} +# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes +NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index d8536474b1..2e7ec757b4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -12,7 +12,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.tools.workflow_as_tool.tool import WorkflowTool from core.variables.segments import ArrayAnySegment, ArrayFileSegment from core.variables.variables import ArrayAnyVariable from core.workflow.enums import ( @@ -430,7 +429,7 @@ class ToolNode(Node[ToolNodeData]): metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, } - if usage.total_tokens > 0: + if isinstance(usage.total_tokens, int) and usage.total_tokens > 0: metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency @@ -449,8 +448,17 @@ class ToolNode(Node[ToolNodeData]): @staticmethod def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: - if isinstance(tool_runtime, WorkflowTool): - return tool_runtime.latest_usage + # Avoid importing WorkflowTool at module import time; rely on duck typing + # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes. + latest = getattr(tool_runtime, "latest_usage", None) + # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects + # for any name, so we must type-check here. + if isinstance(latest, LLMUsage): + return latest + if isinstance(latest, dict): + # Allow dict payloads from external runtimes + return LLMUsage.model_validate(latest) + # Fallback to empty usage when attribute is missing or not a valid payload return LLMUsage.empty_usage() @classmethod diff --git a/api/extensions/ext_forward_refs.py b/api/extensions/ext_forward_refs.py new file mode 100644 index 0000000000..c40b505b16 --- /dev/null +++ b/api/extensions/ext_forward_refs.py @@ -0,0 +1,49 @@ +import logging + +from dify_app import DifyApp + + +def is_enabled() -> bool: + return True + + +def init_app(app: DifyApp): + """Resolve Pydantic forward refs that would otherwise cause circular imports. + + Rebuilds models in core.app.entities.app_invoke_entities with the real TraceQueueManager type. + Safe to run multiple times. + """ + logger = logging.getLogger(__name__) + try: + from core.app.entities.app_invoke_entities import ( + AdvancedChatAppGenerateEntity, + AgentChatAppGenerateEntity, + AppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + ConversationAppGenerateEntity, + EasyUIBasedAppGenerateEntity, + RagPipelineGenerateEntity, + WorkflowAppGenerateEntity, + ) + from core.ops.ops_trace_manager import TraceQueueManager # heavy import, do it at startup only + + ns = {"TraceQueueManager": TraceQueueManager} + for Model in ( + AppGenerateEntity, + EasyUIBasedAppGenerateEntity, + ConversationAppGenerateEntity, + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + WorkflowAppGenerateEntity, + RagPipelineGenerateEntity, + ): + try: + Model.model_rebuild(_types_namespace=ns) + except Exception as e: + logger.debug("model_rebuild skipped for %s: %s", Model.__name__, e) + except Exception as e: + # Don't block app startup; just log at debug level. + logger.debug("ext_forward_refs init skipped: %s", e) diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index 2597a3d65a..5716aae4c7 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]): @classmethod def version(cls) -> str: - return "test" + return "1" def __init__( self, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 68f57ee9fb..fd94a5e833 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock LLM node.""" @@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock agent node.""" @@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock tool node.""" @@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock knowledge retrieval node.""" @@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock HTTP request node.""" @@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock question classifier node.""" @@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock parameter extractor node.""" @@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> Generator: """Execute mock document extractor node.""" @@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" @@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" @@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> NodeRunResult: """Execute mock template transform node.""" @@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode): @classmethod def version(cls) -> str: """Return the version of this mock node.""" - return "mock-1" + return "1" def _run(self) -> NodeRunResult: """Execute mock code node.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 6eead80ac9..488b47761b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined type_version_set: set[tuple[NodeType, str]] = set() for cls in classes: + # Only validate production node classes; skip test-defined subclasses and external helpers + module_name = getattr(cls, "__module__", "") + if not module_name.startswith("core."): + continue # Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__ assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)" node_type = cls.node_type diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py new file mode 100644 index 0000000000..45d222b98c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -0,0 +1,84 @@ +import types +from collections.abc import Mapping + +from core.workflow.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.node import Node + +# Import concrete nodes we will assert on (numeric version path) +from core.workflow.nodes.variable_assigner.v1.node import ( + VariableAssignerNode as VariableAssignerV1, +) +from core.workflow.nodes.variable_assigner.v2.node import ( + VariableAssignerNode as VariableAssignerV2, +) + + +def test_variable_assigner_latest_prefers_highest_numeric_version(): + # Act + mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + + # Assert basic presence + assert NodeType.VARIABLE_ASSIGNER in mapping + va_versions = mapping[NodeType.VARIABLE_ASSIGNER] + + # Both concrete versions must be present + assert va_versions.get("1") is VariableAssignerV1 + assert va_versions.get("2") is VariableAssignerV2 + + # And latest should point to numerically-highest version ("2") + assert va_versions.get("latest") is VariableAssignerV2 + + +def test_latest_prefers_highest_numeric_version(): + # Arrange: define two ephemeral subclasses with numeric versions under a NodeType + # that has no concrete implementations in production to avoid interference. + class _Version1(Node[BaseNodeData]): # type: ignore[misc] + node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR + + def init_node_data(self, data): + pass + + def _run(self): + raise NotImplementedError + + @classmethod + def version(cls) -> str: + return "1" + + def _get_error_strategy(self): + return None + + def _get_retry_config(self): + return types.SimpleNamespace() # not used + + def _get_title(self) -> str: + return "version1" + + def _get_description(self): + return None + + def _get_default_value_dict(self): + return {} + + def get_base_node_data(self): + return types.SimpleNamespace(title="version1") + + class _Version2(_Version1): # type: ignore[misc] + @classmethod + def version(cls) -> str: + return "2" + + def _get_title(self) -> str: + return "version2" + + # Act: build a fresh mapping (it should now see our ephemeral subclasses) + mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + + # Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version + assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping + legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR] + + assert legacy_versions.get("1") is _Version1 + assert legacy_versions.get("2") is _Version2 + assert legacy_versions.get("latest") is _Version2 diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 4a57ab2b89..1854cca236 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]): @classmethod def version(cls) -> str: - return "sample-test" + return "1" def _run(self): raise NotImplementedError