mirror of https://github.com/langgenius/dify.git
feat(api): automatically `NODE_TYPE_CLASSES_MAPPING` generation from node class definitions (#28525)
This commit is contained in:
parent
2f8cb2a1af
commit
d162f7e5ef
|
|
@ -51,6 +51,7 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_commands,
|
ext_commands,
|
||||||
ext_compress,
|
ext_compress,
|
||||||
ext_database,
|
ext_database,
|
||||||
|
ext_forward_refs,
|
||||||
ext_hosting_provider,
|
ext_hosting_provider,
|
||||||
ext_import_modules,
|
ext_import_modules,
|
||||||
ext_logging,
|
ext_logging,
|
||||||
|
|
@ -75,6 +76,7 @@ def initialize_extensions(app: DifyApp):
|
||||||
ext_warnings,
|
ext_warnings,
|
||||||
ext_import_modules,
|
ext_import_modules,
|
||||||
ext_orjson,
|
ext_orjson,
|
||||||
|
ext_forward_refs,
|
||||||
ext_set_secretkey,
|
ext_set_secretkey,
|
||||||
ext_compress,
|
ext_compress,
|
||||||
ext_code_based_extension,
|
ext_code_based_extension,
|
||||||
|
|
|
||||||
|
|
@ -130,7 +130,7 @@ class AppGenerateEntity(BaseModel):
|
||||||
# extra parameters, like: auto_generate_conversation_name
|
# extra parameters, like: auto_generate_conversation_name
|
||||||
extras: dict[str, Any] = Field(default_factory=dict)
|
extras: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
# tracing instance
|
# tracing instance; use forward ref to avoid circular import at import time
|
||||||
trace_manager: Optional["TraceQueueManager"] = None
|
trace_manager: Optional["TraceQueueManager"] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -275,16 +275,23 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||||
start_node_id: str | None = None
|
start_node_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
# Import TraceQueueManager at runtime to resolve forward references
|
# NOTE: Avoid importing heavy tracing modules at import time to prevent circular imports.
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
# Forward reference to TraceQueueManager is kept as a string; we rebuild with a stub now to
|
||||||
|
# avoid Pydantic forward-ref errors in test contexts, and with the real class at app startup.
|
||||||
|
|
||||||
# Rebuild models that use forward references
|
|
||||||
AppGenerateEntity.model_rebuild()
|
# Minimal stub to satisfy Pydantic model_rebuild in environments where the real type is not importable yet.
|
||||||
EasyUIBasedAppGenerateEntity.model_rebuild()
|
class _TraceQueueManagerStub:
|
||||||
ConversationAppGenerateEntity.model_rebuild()
|
pass
|
||||||
ChatAppGenerateEntity.model_rebuild()
|
|
||||||
CompletionAppGenerateEntity.model_rebuild()
|
|
||||||
AgentChatAppGenerateEntity.model_rebuild()
|
_ns = {"TraceQueueManager": _TraceQueueManagerStub}
|
||||||
AdvancedChatAppGenerateEntity.model_rebuild()
|
AppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
WorkflowAppGenerateEntity.model_rebuild()
|
EasyUIBasedAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
RagPipelineGenerateEntity.model_rebuild()
|
ConversationAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
|
ChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
|
CompletionAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
|
AgentChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
|
AdvancedChatAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
|
WorkflowAppGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
|
RagPipelineGenerateEntity.model_rebuild(_types_namespace=_ns)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,11 @@
|
||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
|
import pkgutil
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from functools import singledispatchmethod
|
from functools import singledispatchmethod
|
||||||
|
from types import MappingProxyType
|
||||||
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
@ -134,6 +138,34 @@ class Node(Generic[NodeDataT]):
|
||||||
|
|
||||||
cls._node_data_type = node_data_type
|
cls._node_data_type = node_data_type
|
||||||
|
|
||||||
|
# Skip base class itself
|
||||||
|
if cls is Node:
|
||||||
|
return
|
||||||
|
# Only register production node implementations defined under core.workflow.nodes.*
|
||||||
|
# This prevents test helper subclasses from polluting the global registry and
|
||||||
|
# accidentally overriding real node types (e.g., a test Answer node).
|
||||||
|
module_name = getattr(cls, "__module__", "")
|
||||||
|
# Only register concrete subclasses that define node_type and version()
|
||||||
|
node_type = cls.node_type
|
||||||
|
version = cls.version()
|
||||||
|
bucket = Node._registry.setdefault(node_type, {})
|
||||||
|
if module_name.startswith("core.workflow.nodes."):
|
||||||
|
# Production node definitions take precedence and may override
|
||||||
|
bucket[version] = cls # type: ignore[index]
|
||||||
|
else:
|
||||||
|
# External/test subclasses may register but must not override production
|
||||||
|
bucket.setdefault(version, cls) # type: ignore[index]
|
||||||
|
# Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic
|
||||||
|
version_keys = [v for v in bucket if v != "latest"]
|
||||||
|
numeric_pairs: list[tuple[str, int]] = []
|
||||||
|
for v in version_keys:
|
||||||
|
numeric_pairs.append((v, int(v)))
|
||||||
|
if numeric_pairs:
|
||||||
|
latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0]
|
||||||
|
else:
|
||||||
|
latest_key = max(version_keys) if version_keys else version
|
||||||
|
bucket["latest"] = bucket[latest_key]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||||
"""
|
"""
|
||||||
|
|
@ -165,6 +197,9 @@ class Node(Generic[NodeDataT]):
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Global registry populated via __init_subclass__
|
||||||
|
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
id: str,
|
id: str,
|
||||||
|
|
@ -395,6 +430,29 @@ class Node(Generic[NodeDataT]):
|
||||||
# in `api/core/workflow/nodes/__init__.py`.
|
# in `api/core/workflow/nodes/__init__.py`.
|
||||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_node_type_classes_mapping(cls) -> Mapping["NodeType", Mapping[str, type["Node"]]]:
|
||||||
|
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||||
|
|
||||||
|
Import all modules under core.workflow.nodes so subclasses register themselves on import.
|
||||||
|
Then we return a readonly view of the registry to avoid accidental mutation.
|
||||||
|
"""
|
||||||
|
# Import all node modules to ensure they are loaded (thus registered)
|
||||||
|
import core.workflow.nodes as _nodes_pkg
|
||||||
|
|
||||||
|
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
|
||||||
|
# Avoid importing modules that depend on the registry to prevent circular imports
|
||||||
|
# e.g. node_factory imports node_mapping which builds the mapping here.
|
||||||
|
if _modname in {
|
||||||
|
"core.workflow.nodes.node_factory",
|
||||||
|
"core.workflow.nodes.node_mapping",
|
||||||
|
}:
|
||||||
|
continue
|
||||||
|
importlib.import_module(_modname)
|
||||||
|
|
||||||
|
# Return a readonly view so callers can't mutate the registry by accident
|
||||||
|
return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def retry(self) -> bool:
|
def retry(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -1,165 +1,9 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
|
||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.code import CodeNode
|
|
||||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
|
||||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
|
||||||
from core.workflow.nodes.end.end_node import EndNode
|
|
||||||
from core.workflow.nodes.http_request import HttpRequestNode
|
|
||||||
from core.workflow.nodes.human_input import HumanInputNode
|
|
||||||
from core.workflow.nodes.if_else import IfElseNode
|
|
||||||
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
|
||||||
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
|
|
||||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
|
||||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
|
||||||
from core.workflow.nodes.llm import LLMNode
|
|
||||||
from core.workflow.nodes.loop import LoopEndNode, LoopNode, LoopStartNode
|
|
||||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
|
||||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
|
||||||
from core.workflow.nodes.start import StartNode
|
|
||||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
|
||||||
from core.workflow.nodes.tool import ToolNode
|
|
||||||
from core.workflow.nodes.trigger_plugin import TriggerEventNode
|
|
||||||
from core.workflow.nodes.trigger_schedule import TriggerScheduleNode
|
|
||||||
from core.workflow.nodes.trigger_webhook import TriggerWebhookNode
|
|
||||||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
|
||||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
|
||||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
|
||||||
|
|
||||||
LATEST_VERSION = "latest"
|
LATEST_VERSION = "latest"
|
||||||
|
|
||||||
# NOTE(QuantumGhost): This should be in sync with subclasses of BaseNode.
|
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes
|
||||||
# Specifically, if you have introduced new node types, you should add them here.
|
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||||
#
|
|
||||||
# TODO(QuantumGhost): This could be automated with either metaclass or `__init_subclass__`
|
|
||||||
# hook. Try to avoid duplication of node information.
|
|
||||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
|
|
||||||
NodeType.START: {
|
|
||||||
LATEST_VERSION: StartNode,
|
|
||||||
"1": StartNode,
|
|
||||||
},
|
|
||||||
NodeType.END: {
|
|
||||||
LATEST_VERSION: EndNode,
|
|
||||||
"1": EndNode,
|
|
||||||
},
|
|
||||||
NodeType.ANSWER: {
|
|
||||||
LATEST_VERSION: AnswerNode,
|
|
||||||
"1": AnswerNode,
|
|
||||||
},
|
|
||||||
NodeType.LLM: {
|
|
||||||
LATEST_VERSION: LLMNode,
|
|
||||||
"1": LLMNode,
|
|
||||||
},
|
|
||||||
NodeType.KNOWLEDGE_RETRIEVAL: {
|
|
||||||
LATEST_VERSION: KnowledgeRetrievalNode,
|
|
||||||
"1": KnowledgeRetrievalNode,
|
|
||||||
},
|
|
||||||
NodeType.IF_ELSE: {
|
|
||||||
LATEST_VERSION: IfElseNode,
|
|
||||||
"1": IfElseNode,
|
|
||||||
},
|
|
||||||
NodeType.CODE: {
|
|
||||||
LATEST_VERSION: CodeNode,
|
|
||||||
"1": CodeNode,
|
|
||||||
},
|
|
||||||
NodeType.TEMPLATE_TRANSFORM: {
|
|
||||||
LATEST_VERSION: TemplateTransformNode,
|
|
||||||
"1": TemplateTransformNode,
|
|
||||||
},
|
|
||||||
NodeType.QUESTION_CLASSIFIER: {
|
|
||||||
LATEST_VERSION: QuestionClassifierNode,
|
|
||||||
"1": QuestionClassifierNode,
|
|
||||||
},
|
|
||||||
NodeType.HTTP_REQUEST: {
|
|
||||||
LATEST_VERSION: HttpRequestNode,
|
|
||||||
"1": HttpRequestNode,
|
|
||||||
},
|
|
||||||
NodeType.TOOL: {
|
|
||||||
LATEST_VERSION: ToolNode,
|
|
||||||
# This is an issue that caused problems before.
|
|
||||||
# Logically, we shouldn't use two different versions to point to the same class here,
|
|
||||||
# but in order to maintain compatibility with historical data, this approach has been retained.
|
|
||||||
"2": ToolNode,
|
|
||||||
"1": ToolNode,
|
|
||||||
},
|
|
||||||
NodeType.VARIABLE_AGGREGATOR: {
|
|
||||||
LATEST_VERSION: VariableAggregatorNode,
|
|
||||||
"1": VariableAggregatorNode,
|
|
||||||
},
|
|
||||||
NodeType.LEGACY_VARIABLE_AGGREGATOR: {
|
|
||||||
LATEST_VERSION: VariableAggregatorNode,
|
|
||||||
"1": VariableAggregatorNode,
|
|
||||||
}, # original name of VARIABLE_AGGREGATOR
|
|
||||||
NodeType.ITERATION: {
|
|
||||||
LATEST_VERSION: IterationNode,
|
|
||||||
"1": IterationNode,
|
|
||||||
},
|
|
||||||
NodeType.ITERATION_START: {
|
|
||||||
LATEST_VERSION: IterationStartNode,
|
|
||||||
"1": IterationStartNode,
|
|
||||||
},
|
|
||||||
NodeType.LOOP: {
|
|
||||||
LATEST_VERSION: LoopNode,
|
|
||||||
"1": LoopNode,
|
|
||||||
},
|
|
||||||
NodeType.LOOP_START: {
|
|
||||||
LATEST_VERSION: LoopStartNode,
|
|
||||||
"1": LoopStartNode,
|
|
||||||
},
|
|
||||||
NodeType.LOOP_END: {
|
|
||||||
LATEST_VERSION: LoopEndNode,
|
|
||||||
"1": LoopEndNode,
|
|
||||||
},
|
|
||||||
NodeType.PARAMETER_EXTRACTOR: {
|
|
||||||
LATEST_VERSION: ParameterExtractorNode,
|
|
||||||
"1": ParameterExtractorNode,
|
|
||||||
},
|
|
||||||
NodeType.VARIABLE_ASSIGNER: {
|
|
||||||
LATEST_VERSION: VariableAssignerNodeV2,
|
|
||||||
"1": VariableAssignerNodeV1,
|
|
||||||
"2": VariableAssignerNodeV2,
|
|
||||||
},
|
|
||||||
NodeType.DOCUMENT_EXTRACTOR: {
|
|
||||||
LATEST_VERSION: DocumentExtractorNode,
|
|
||||||
"1": DocumentExtractorNode,
|
|
||||||
},
|
|
||||||
NodeType.LIST_OPERATOR: {
|
|
||||||
LATEST_VERSION: ListOperatorNode,
|
|
||||||
"1": ListOperatorNode,
|
|
||||||
},
|
|
||||||
NodeType.AGENT: {
|
|
||||||
LATEST_VERSION: AgentNode,
|
|
||||||
# This is an issue that caused problems before.
|
|
||||||
# Logically, we shouldn't use two different versions to point to the same class here,
|
|
||||||
# but in order to maintain compatibility with historical data, this approach has been retained.
|
|
||||||
"2": AgentNode,
|
|
||||||
"1": AgentNode,
|
|
||||||
},
|
|
||||||
NodeType.HUMAN_INPUT: {
|
|
||||||
LATEST_VERSION: HumanInputNode,
|
|
||||||
"1": HumanInputNode,
|
|
||||||
},
|
|
||||||
NodeType.DATASOURCE: {
|
|
||||||
LATEST_VERSION: DatasourceNode,
|
|
||||||
"1": DatasourceNode,
|
|
||||||
},
|
|
||||||
NodeType.KNOWLEDGE_INDEX: {
|
|
||||||
LATEST_VERSION: KnowledgeIndexNode,
|
|
||||||
"1": KnowledgeIndexNode,
|
|
||||||
},
|
|
||||||
NodeType.TRIGGER_WEBHOOK: {
|
|
||||||
LATEST_VERSION: TriggerWebhookNode,
|
|
||||||
"1": TriggerWebhookNode,
|
|
||||||
},
|
|
||||||
NodeType.TRIGGER_PLUGIN: {
|
|
||||||
LATEST_VERSION: TriggerEventNode,
|
|
||||||
"1": TriggerEventNode,
|
|
||||||
},
|
|
||||||
NodeType.TRIGGER_SCHEDULE: {
|
|
||||||
LATEST_VERSION: TriggerScheduleNode,
|
|
||||||
"1": TriggerScheduleNode,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||||
from core.tools.errors import ToolInvokeError
|
from core.tools.errors import ToolInvokeError
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
|
||||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||||
from core.variables.variables import ArrayAnyVariable
|
from core.variables.variables import ArrayAnyVariable
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
|
|
@ -430,7 +429,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||||
}
|
}
|
||||||
if usage.total_tokens > 0:
|
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
|
||||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||||
|
|
@ -449,8 +448,17 @@ class ToolNode(Node[ToolNodeData]):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage:
|
||||||
if isinstance(tool_runtime, WorkflowTool):
|
# Avoid importing WorkflowTool at module import time; rely on duck typing
|
||||||
return tool_runtime.latest_usage
|
# Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes.
|
||||||
|
latest = getattr(tool_runtime, "latest_usage", None)
|
||||||
|
# Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects
|
||||||
|
# for any name, so we must type-check here.
|
||||||
|
if isinstance(latest, LLMUsage):
|
||||||
|
return latest
|
||||||
|
if isinstance(latest, dict):
|
||||||
|
# Allow dict payloads from external runtimes
|
||||||
|
return LLMUsage.model_validate(latest)
|
||||||
|
# Fallback to empty usage when attribute is missing or not a valid payload
|
||||||
return LLMUsage.empty_usage()
|
return LLMUsage.empty_usage()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from dify_app import DifyApp
|
||||||
|
|
||||||
|
|
||||||
|
def is_enabled() -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def init_app(app: DifyApp):
|
||||||
|
"""Resolve Pydantic forward refs that would otherwise cause circular imports.
|
||||||
|
|
||||||
|
Rebuilds models in core.app.entities.app_invoke_entities with the real TraceQueueManager type.
|
||||||
|
Safe to run multiple times.
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
try:
|
||||||
|
from core.app.entities.app_invoke_entities import (
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
AgentChatAppGenerateEntity,
|
||||||
|
AppGenerateEntity,
|
||||||
|
ChatAppGenerateEntity,
|
||||||
|
CompletionAppGenerateEntity,
|
||||||
|
ConversationAppGenerateEntity,
|
||||||
|
EasyUIBasedAppGenerateEntity,
|
||||||
|
RagPipelineGenerateEntity,
|
||||||
|
WorkflowAppGenerateEntity,
|
||||||
|
)
|
||||||
|
from core.ops.ops_trace_manager import TraceQueueManager # heavy import, do it at startup only
|
||||||
|
|
||||||
|
ns = {"TraceQueueManager": TraceQueueManager}
|
||||||
|
for Model in (
|
||||||
|
AppGenerateEntity,
|
||||||
|
EasyUIBasedAppGenerateEntity,
|
||||||
|
ConversationAppGenerateEntity,
|
||||||
|
ChatAppGenerateEntity,
|
||||||
|
CompletionAppGenerateEntity,
|
||||||
|
AgentChatAppGenerateEntity,
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
WorkflowAppGenerateEntity,
|
||||||
|
RagPipelineGenerateEntity,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
Model.model_rebuild(_types_namespace=ns)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("model_rebuild skipped for %s: %s", Model.__name__, e)
|
||||||
|
except Exception as e:
|
||||||
|
# Don't block app startup; just log at debug level.
|
||||||
|
logger.debug("ext_forward_refs init skipped: %s", e)
|
||||||
|
|
@ -29,7 +29,7 @@ class _TestNode(Node[_TestNodeData]):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "test"
|
return "1"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ class MockLLMNode(MockNodeMixin, LLMNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Execute mock LLM node."""
|
"""Execute mock LLM node."""
|
||||||
|
|
@ -189,7 +189,7 @@ class MockAgentNode(MockNodeMixin, AgentNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Execute mock agent node."""
|
"""Execute mock agent node."""
|
||||||
|
|
@ -241,7 +241,7 @@ class MockToolNode(MockNodeMixin, ToolNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Execute mock tool node."""
|
"""Execute mock tool node."""
|
||||||
|
|
@ -294,7 +294,7 @@ class MockKnowledgeRetrievalNode(MockNodeMixin, KnowledgeRetrievalNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Execute mock knowledge retrieval node."""
|
"""Execute mock knowledge retrieval node."""
|
||||||
|
|
@ -351,7 +351,7 @@ class MockHttpRequestNode(MockNodeMixin, HttpRequestNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Execute mock HTTP request node."""
|
"""Execute mock HTTP request node."""
|
||||||
|
|
@ -404,7 +404,7 @@ class MockQuestionClassifierNode(MockNodeMixin, QuestionClassifierNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Execute mock question classifier node."""
|
"""Execute mock question classifier node."""
|
||||||
|
|
@ -452,7 +452,7 @@ class MockParameterExtractorNode(MockNodeMixin, ParameterExtractorNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Execute mock parameter extractor node."""
|
"""Execute mock parameter extractor node."""
|
||||||
|
|
@ -502,7 +502,7 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""Execute mock document extractor node."""
|
"""Execute mock document extractor node."""
|
||||||
|
|
@ -557,7 +557,7 @@ class MockIterationNode(MockNodeMixin, IterationNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _create_graph_engine(self, index: int, item: Any):
|
def _create_graph_engine(self, index: int, item: Any):
|
||||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||||
|
|
@ -632,7 +632,7 @@ class MockLoopNode(MockNodeMixin, LoopNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _create_graph_engine(self, start_at, root_node_id: str):
|
def _create_graph_engine(self, start_at, root_node_id: str):
|
||||||
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
"""Create a graph engine with MockNodeFactory instead of DifyNodeFactory."""
|
||||||
|
|
@ -694,7 +694,7 @@ class MockTemplateTransformNode(MockNodeMixin, TemplateTransformNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
"""Execute mock template transform node."""
|
"""Execute mock template transform node."""
|
||||||
|
|
@ -780,7 +780,7 @@ class MockCodeNode(MockNodeMixin, CodeNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
"""Return the version of this mock node."""
|
"""Return the version of this mock node."""
|
||||||
return "mock-1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
"""Execute mock code node."""
|
"""Execute mock code node."""
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,10 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
||||||
type_version_set: set[tuple[NodeType, str]] = set()
|
type_version_set: set[tuple[NodeType, str]] = set()
|
||||||
|
|
||||||
for cls in classes:
|
for cls in classes:
|
||||||
|
# Only validate production node classes; skip test-defined subclasses and external helpers
|
||||||
|
module_name = getattr(cls, "__module__", "")
|
||||||
|
if not module_name.startswith("core."):
|
||||||
|
continue
|
||||||
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
|
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
|
||||||
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
|
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
|
||||||
node_type = cls.node_type
|
node_type = cls.node_type
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,84 @@
|
||||||
|
import types
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
|
from core.workflow.enums import NodeType
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
|
from core.workflow.nodes.base.node import Node
|
||||||
|
|
||||||
|
# Import concrete nodes we will assert on (numeric version path)
|
||||||
|
from core.workflow.nodes.variable_assigner.v1.node import (
|
||||||
|
VariableAssignerNode as VariableAssignerV1,
|
||||||
|
)
|
||||||
|
from core.workflow.nodes.variable_assigner.v2.node import (
|
||||||
|
VariableAssignerNode as VariableAssignerV2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_variable_assigner_latest_prefers_highest_numeric_version():
|
||||||
|
# Act
|
||||||
|
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||||
|
|
||||||
|
# Assert basic presence
|
||||||
|
assert NodeType.VARIABLE_ASSIGNER in mapping
|
||||||
|
va_versions = mapping[NodeType.VARIABLE_ASSIGNER]
|
||||||
|
|
||||||
|
# Both concrete versions must be present
|
||||||
|
assert va_versions.get("1") is VariableAssignerV1
|
||||||
|
assert va_versions.get("2") is VariableAssignerV2
|
||||||
|
|
||||||
|
# And latest should point to numerically-highest version ("2")
|
||||||
|
assert va_versions.get("latest") is VariableAssignerV2
|
||||||
|
|
||||||
|
|
||||||
|
def test_latest_prefers_highest_numeric_version():
|
||||||
|
# Arrange: define two ephemeral subclasses with numeric versions under a NodeType
|
||||||
|
# that has no concrete implementations in production to avoid interference.
|
||||||
|
class _Version1(Node[BaseNodeData]): # type: ignore[misc]
|
||||||
|
node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR
|
||||||
|
|
||||||
|
def init_node_data(self, data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
def _get_error_strategy(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_retry_config(self):
|
||||||
|
return types.SimpleNamespace() # not used
|
||||||
|
|
||||||
|
def _get_title(self) -> str:
|
||||||
|
return "version1"
|
||||||
|
|
||||||
|
def _get_description(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_default_value_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_base_node_data(self):
|
||||||
|
return types.SimpleNamespace(title="version1")
|
||||||
|
|
||||||
|
class _Version2(_Version1): # type: ignore[misc]
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "2"
|
||||||
|
|
||||||
|
def _get_title(self) -> str:
|
||||||
|
return "version2"
|
||||||
|
|
||||||
|
# Act: build a fresh mapping (it should now see our ephemeral subclasses)
|
||||||
|
mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||||
|
|
||||||
|
# Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version
|
||||||
|
assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping
|
||||||
|
legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR]
|
||||||
|
|
||||||
|
assert legacy_versions.get("1") is _Version1
|
||||||
|
assert legacy_versions.get("2") is _Version2
|
||||||
|
assert legacy_versions.get("latest") is _Version2
|
||||||
|
|
@ -19,7 +19,7 @@ class _SampleNode(Node[_SampleNodeData]):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "sample-test"
|
return "1"
|
||||||
|
|
||||||
def _run(self):
|
def _run(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue