Refactor: centralize node data hydration (#27771)

This commit is contained in:
-LAN- 2025-11-27 15:41:56 +08:00 committed by GitHub
parent 1b733abe82
commit 13bf6547ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 381 additions and 899 deletions

View File

@ -152,10 +152,5 @@ class CodeExecutor:
raise CodeExecutionError(f"Unsupported language {language}")
runner, preload = template_transformer.transform_caller(code, inputs)
try:
response = cls.execute_code(language, preload, runner)
except CodeExecutionError as e:
raise e
return template_transformer.transform_response(response)

View File

@ -26,7 +26,6 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
@ -40,7 +39,6 @@ from core.workflow.node_events import (
StreamCompletedEvent,
)
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
@ -66,7 +64,7 @@ if TYPE_CHECKING:
from core.plugin.entities.request import InvokeCredentials
class AgentNode(Node):
class AgentNode(Node[AgentNodeData]):
"""
Agent Node
"""
@ -74,27 +72,6 @@ class AgentNode(Node):
node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -2,42 +2,20 @@ from collections.abc import Mapping, Sequence
from typing import Any
from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
class AnswerNode(Node):
class AnswerNode(Node[AnswerNodeData]):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.RESPONSE
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -2,7 +2,7 @@ import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from typing import Any, ClassVar
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
@ -49,12 +49,121 @@ from models.enums import UserFrom
from .entities import BaseNodeData, RetryConfig
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
logger = logging.getLogger(__name__)
class Node:
class Node(Generic[NodeDataT]):
node_type: ClassVar["NodeType"]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
def __init_subclass__(cls, **kwargs: Any) -> None:
"""
Automatically extract and validate the node data type from the generic parameter.
When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method:
1. Inspects `__orig_bases__` to find the `Node[T]` parameterization
2. Extracts `T` (e.g., `MyNodeData`) from the generic argument
3. Validates that `T` is a proper `BaseNodeData` subclass
4. Stores it in `_node_data_type` for automatic hydration in `__init__`
This eliminates the need for subclasses to manually implement boilerplate
accessor methods like `_get_title()`, `_get_error_strategy()`, etc.
How it works:
::
class CodeNode(Node[CodeNodeData]):
__orig_bases__ = ( CodeNodeData(BaseNodeData)
Node[CodeNodeData], title: str
) desc: str | None
...
get_origin(base) -> Node
get_args(base) -> (
CodeNodeData,
)
Validate:
- Is it a type?
- Is it a BaseNodeData
subclass?
cls._node_data_type =
CodeNodeData
Later, in __init__:
::
config["data"] _hydrate_node_data() _node_data_type.model_validate()
CodeNodeData instance
(stored in self._node_data)
Example:
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
node_type = NodeType.CODE
# No need to implement _get_title, _get_error_strategy, etc.
"""
super().__init_subclass__(**kwargs)
if cls is Node:
return
node_data_type = cls._extract_node_data_type_from_generic()
if node_data_type is None:
raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype")
cls._node_data_type = node_data_type
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
"""
Extract the node data type from the generic parameter `Node[T]`.
Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`.
Returns:
The extracted BaseNodeData subtype, or None if not found.
Raises:
TypeError: If the generic argument is invalid (not exactly one argument,
or not a BaseNodeData subtype).
"""
# __orig_bases__ contains the original generic bases before type erasure.
# For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`.
for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined]
origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]`
if origin is Node:
args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]`
if len(args) != 1:
raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument")
candidate = args[0]
if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")
return candidate
return None
def __init__(
self,
@ -63,6 +172,7 @@ class Node:
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
self._graph_init_params = graph_init_params
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
@ -83,8 +193,24 @@ class Node:
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
raise ValueError("Node config data must be a mapping.")
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
self.post_init()
def post_init(self) -> None:
"""Optional hook for subclasses requiring extra initialization."""
return
@property
def graph_init_params(self) -> "GraphInitParams":
return self._graph_init_params
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@abstractmethod
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
@ -273,38 +399,29 @@ class Node:
def retry(self) -> bool:
return False
# Abstract methods that subclasses must implement to provide access
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
...
return self._node_data.error_strategy
@abstractmethod
def _get_retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
...
return self._node_data.retry_config
@abstractmethod
def _get_title(self) -> str:
"""Get the node title."""
...
return self._node_data.title
@abstractmethod
def _get_description(self) -> str | None:
"""Get the node description."""
...
return self._node_data.desc
@abstractmethod
def _get_default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
...
return self._node_data.default_value_dict
@abstractmethod
def get_base_node_data(self) -> BaseNodeData:
"""Get the BaseNodeData object for this node."""
...
return self._node_data
# Public interface properties that delegate to abstract methods
@property
@ -332,6 +449,11 @@ class Node:
"""Get the default values dictionary for this node."""
return self._get_default_value_dict()
@property
def node_data(self) -> NodeDataT:
"""Typed access to this node's configuration data."""
return self._node_data
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
match result.status:
case WorkflowNodeExecutionStatus.FAILED:

View File

@ -9,9 +9,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
@ -22,32 +21,11 @@ from .exc import (
)
class CodeNode(Node):
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""

View File

@ -20,9 +20,8 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
@ -38,7 +37,7 @@ from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError
class DatasourceNode(Node):
class DatasourceNode(Node[DatasourceNodeData]):
"""
Datasource Node
"""
@ -47,27 +46,6 @@ class DatasourceNode(Node):
node_type = NodeType.DATASOURCE
execution_type = NodeExecutionType.ROOT
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = DatasourceNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self) -> Generator:
"""
Run the datasource node

View File

@ -25,9 +25,8 @@ from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import DocumentExtractorNodeData
@ -36,7 +35,7 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)
class DocumentExtractorNode(Node):
class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
@ -46,27 +45,6 @@ class DocumentExtractorNode(Node):
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -1,41 +1,16 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.entities import EndNodeData
class EndNode(Node):
class EndNode(Node[EndNodeData]):
node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -7,10 +7,10 @@ from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from factories import file_factory
@ -31,32 +31,11 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
class HttpRequestNode(Node):
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {

View File

@ -2,15 +2,14 @@ from collections.abc import Mapping
from typing import Any
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import HumanInputNodeData
class HumanInputNode(Node):
class HumanInputNode(Node[HumanInputNodeData]):
node_type = NodeType.HUMAN_INPUT
execution_type = NodeExecutionType.BRANCH
@ -28,31 +27,10 @@ class HumanInputNode(Node):
_node_data: HumanInputNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = HumanInputNodeData(**data)
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def _run(self): # type: ignore[override]
if self._is_completion_ready():
branch_handle = self._resolve_branch_selection()

View File

@ -3,9 +3,8 @@ from typing import Any, Literal
from typing_extensions import deprecated
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
@ -13,33 +12,12 @@ from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(Node):
class IfElseNode(Node[IfElseNodeData]):
node_type = NodeType.IF_ELSE
execution_type = NodeExecutionType.BRANCH
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -14,7 +14,6 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
@ -36,7 +35,6 @@ from core.workflow.node_events import (
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
@ -60,7 +58,7 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(LLMUsageTrackingMixin, Node):
class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
"""
Iteration Node.
"""
@ -69,27 +67,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {

View File

@ -1,14 +1,10 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(Node):
class IterationStartNode(Node[IterationStartNodeData]):
"""
Iteration Start Node.
"""
@ -17,27 +13,6 @@ class IterationStartNode(Node):
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -10,9 +10,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
@ -35,32 +34,11 @@ default_retrieval_model = {
}
class KnowledgeIndexNode(Node):
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
_node_data: KnowledgeIndexNodeData
node_type = NodeType.KNOWLEDGE_INDEX
execution_type = NodeExecutionType.RESPONSE
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = KnowledgeIndexNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self) -> NodeRunResult: # type: ignore
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool

View File

@ -30,14 +30,12 @@ from core.variables import (
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
@ -82,7 +80,7 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
@ -118,27 +116,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls):
return "1"

View File

@ -1,12 +1,11 @@
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import FilterOperator, ListOperatorNodeData, Order
@ -35,32 +34,11 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
return wrapper
class ListOperatorNode(Node):
class ListOperatorNode(Node[ListOperatorNodeData]):
node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -55,7 +55,6 @@ from core.variables import (
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
@ -69,7 +68,7 @@ from core.workflow.node_events import (
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
@ -100,7 +99,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LLMNode(Node):
class LLMNode(Node[LLMNodeData]):
node_type = NodeType.LLM
_node_data: LLMNodeData
@ -139,27 +138,6 @@ class LLMNode(Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -1,14 +1,10 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(Node):
class LoopEndNode(Node[LoopEndNodeData]):
"""
Loop End Node.
"""
@ -17,27 +13,6 @@ class LoopEndNode(Node):
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
@ -29,7 +28,6 @@ from core.workflow.node_events import (
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
@ -42,7 +40,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LoopNode(LLMUsageTrackingMixin, Node):
class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
"""
Loop Node.
"""
@ -51,27 +49,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -1,14 +1,10 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(Node):
class LoopStartNode(Node[LoopStartNodeData]):
"""
Loop Start Node.
"""
@ -17,27 +13,6 @@ class LoopStartNode(Node):
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -69,17 +69,9 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
node_instance = node_class(
return node_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
# Initialize node with provided data
node_data = node_config.get("data", {})
if not is_str_dict(node_data):
raise ValueError(f"Node {node_id} missing data information")
node_instance.init_node_data(node_data)
return node_instance

View File

@ -27,10 +27,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
@ -84,7 +83,7 @@ def extract_json(text):
return None
class ParameterExtractorNode(Node):
class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
Parameter Extractor Node.
"""
@ -93,27 +92,6 @@ class ParameterExtractorNode(Node):
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None

View File

@ -13,14 +13,13 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
@ -44,7 +43,7 @@ if TYPE_CHECKING:
from core.workflow.runtime import GraphRuntimeState
class QuestionClassifierNode(Node):
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
node_type = NodeType.QUESTION_CLASSIFIER
execution_type = NodeExecutionType.BRANCH
@ -78,27 +77,6 @@ class QuestionClassifierNode(Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls):
return "1"

View File

@ -1,41 +1,16 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
class StartNode(Node):
class StartNode(Node[StartNodeData]):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -3,41 +3,19 @@ from typing import Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node):
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""

View File

@ -16,14 +16,12 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
@ -42,7 +40,7 @@ if TYPE_CHECKING:
from core.workflow.runtime import VariablePool
class ToolNode(Node):
class ToolNode(Node[ToolNodeData]):
"""
Tool Node
"""
@ -51,9 +49,6 @@ class ToolNode(Node):
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ToolNodeData.model_validate(data)
@classmethod
def version(cls) -> str:
return "1"
@ -498,24 +493,6 @@ class ToolNode(Node):
return result
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -1,43 +1,18 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import TriggerEventNodeData
class TriggerEventNode(Node):
class TriggerEventNode(Node[TriggerEventNodeData]):
node_type = NodeType.TRIGGER_PLUGIN
execution_type = NodeExecutionType.ROOT
_node_data: TriggerEventNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TriggerEventNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {

View File

@ -1,42 +1,17 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData
class TriggerScheduleNode(Node):
class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
node_type = NodeType.TRIGGER_SCHEDULE
execution_type = NodeExecutionType.ROOT
_node_data: TriggerScheduleNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TriggerScheduleNodeData(**data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -3,41 +3,17 @@ from typing import Any
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import ContentType, WebhookData
class TriggerWebhookNode(Node):
class TriggerWebhookNode(Node[WebhookData]):
node_type = NodeType.TRIGGER_WEBHOOK
execution_type = NodeExecutionType.ROOT
_node_data: WebhookData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = WebhookData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {

View File

@ -1,40 +1,17 @@
from collections.abc import Mapping
from typing import Any
from core.variables.segments import Segment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(Node):
class VariableAggregatorNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_AGGREGATOR
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -5,9 +5,8 @@ from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
@ -22,33 +21,12 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(Node):
class VariableAssignerNode(Node[VariableAssignerData]):
node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def __init__(
self,
id: str,

View File

@ -7,9 +7,8 @@ from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
@ -51,32 +50,11 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector
class VariableAssignerNode(Node):
class VariableAssignerNode(Node[VariableAssignerNodeData]):
node_type = NodeType.VARIABLE_ASSIGNER
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.

View File

@ -159,7 +159,6 @@ class WorkflowEntry:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(node_config_data)
try:
# variable selector to variable mapping
@ -303,7 +302,6 @@ class WorkflowEntry:
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(node_data)
try:
# variable selector to variable mapping

View File

@ -69,10 +69,6 @@ def init_code_node(code_config: dict):
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in code_config:
node.init_node_data(code_config["data"])
return node

View File

@ -65,10 +65,6 @@ def init_http_node(config: dict):
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
@ -709,10 +705,6 @@ def test_nested_object_variable_selector(setup_http_mock):
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in graph_config["nodes"][1]:
node.init_node_data(graph_config["nodes"][1]["data"])
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")

View File

@ -82,10 +82,6 @@ def init_llm_node(config: dict) -> LLMNode:
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node

View File

@ -85,7 +85,6 @@ def init_parameter_extractor_node(config: dict):
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))
return node

View File

@ -82,7 +82,6 @@ def test_execute_code(setup_code_executor_mock):
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))
# execute node
result = node._run()

View File

@ -62,7 +62,6 @@ def init_tool_node(config: dict):
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))
return node

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
import pytest
@ -12,14 +11,19 @@ from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
class _TestNode(Node):
class _TestNodeData(BaseNodeData):
type: NodeType | str | None = None
execution_type: NodeExecutionType | str | None = None
class _TestNode(Node[_TestNodeData]):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.EXECUTABLE
@ -41,31 +45,8 @@ class _TestNode(Node):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
data = config.get("data", {})
if isinstance(data, Mapping):
execution_type = data.get("execution_type")
if isinstance(execution_type, str):
self.execution_type = NodeExecutionType(execution_type)
self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
self.data: dict[str, object] = {}
def init_node_data(self, data: Mapping[str, object]) -> None:
title = str(data.get("title", self.id))
desc = data.get("description")
error_strategy_value = data.get("error_strategy")
error_strategy: ErrorStrategy | None = None
if isinstance(error_strategy_value, ErrorStrategy):
error_strategy = error_strategy_value
elif isinstance(error_strategy_value, str):
error_strategy = ErrorStrategy(error_strategy_value)
self._base_node_data = BaseNodeData(
title=title,
desc=str(desc) if desc is not None else None,
error_strategy=error_strategy,
)
self.data = dict(data)
node_type_value = data.get("type")
node_type_value = self.data.get("type")
if isinstance(node_type_value, NodeType):
self.node_type = node_type_value
elif isinstance(node_type_value, str):
@ -77,23 +58,19 @@ class _TestNode(Node):
def _run(self):
raise NotImplementedError
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._base_node_data.error_strategy
def post_init(self) -> None:
super().post_init()
self._maybe_override_execution_type()
self.data = dict(self.node_data.model_dump())
def _get_retry_config(self) -> RetryConfig:
return self._base_node_data.retry_config
def _get_title(self) -> str:
return self._base_node_data.title
def _get_description(self) -> str | None:
return self._base_node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._base_node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._base_node_data
def _maybe_override_execution_type(self) -> None:
execution_type_value = self.node_data.execution_type
if execution_type_value is None:
return
if isinstance(execution_type_value, NodeExecutionType):
self.execution_type = execution_type_value
else:
self.execution_type = NodeExecutionType(execution_type_value)
@dataclass(slots=True)
@ -109,7 +86,6 @@ class _SimpleNodeFactory:
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
node.init_node_data(node_config.get("data", {}))
return node

View File

@ -32,7 +32,7 @@ def test_abort_command():
# Create mock nodes with required attributes - using shared runtime state
start_node = StartNode(
id="start",
config={"id": "start"},
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@ -45,7 +45,6 @@ def test_abort_command():
),
graph_runtime_state=shared_runtime_state,
)
start_node.init_node_data({"title": "start", "variables": []})
mock_graph.nodes["start"] = start_node
# Mock graph methods
@ -142,7 +141,7 @@ def test_pause_command():
start_node = StartNode(
id="start",
config={"id": "start"},
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@ -155,7 +154,6 @@ def test_pause_command():
),
graph_runtime_state=shared_runtime_state,
)
start_node.init_node_data({"title": "start", "variables": []})
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])

View File

@ -63,7 +63,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
@ -88,7 +87,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
@ -105,7 +103,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
human_node.init_node_data(human_config["data"])
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
@ -125,7 +122,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_primary.init_node_data(end_primary_config["data"])
end_secondary_data = EndNodeData(
title="End Secondary",
@ -142,7 +138,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_secondary.init_node_data(end_secondary_config["data"])
graph = (
Graph.new()

View File

@ -62,7 +62,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
@ -87,7 +86,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
@ -104,7 +102,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
human_node.init_node_data(human_config["data"])
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
@ -123,7 +120,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_node.init_node_data(end_config["data"])
graph = (
Graph.new()

View File

@ -62,7 +62,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
@ -87,7 +86,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
@ -118,7 +116,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if_else_node.init_node_data(if_else_config["data"])
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
@ -138,7 +135,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_primary.init_node_data(end_primary_config["data"])
end_secondary_data = EndNodeData(
title="End Secondary",
@ -155,7 +151,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_secondary.init_node_data(end_secondary_config["data"])
graph = (
Graph.new()

View File

@ -111,9 +111,6 @@ class MockNodeFactory(DifyNodeFactory):
mock_config=self.mock_config,
)
# Initialize node with provided data
mock_instance.init_node_data(node_data)
return mock_instance
# For non-mocked node types, use parent implementation

View File

@ -142,6 +142,8 @@ def test_mock_loop_node_preserves_config():
"start_node_id": "node1",
"loop_variables": [],
"outputs": {},
"break_conditions": [],
"logical_operator": "and",
},
}

View File

@ -63,7 +63,6 @@ class TestMockTemplateTransformNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -125,7 +124,6 @@ class TestMockTemplateTransformNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -184,7 +182,6 @@ class TestMockTemplateTransformNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -246,7 +243,6 @@ class TestMockTemplateTransformNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -311,7 +307,6 @@ class TestMockCodeNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -376,7 +371,6 @@ class TestMockCodeNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -445,7 +439,6 @@ class TestMockCodeNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()

View File

@ -83,9 +83,6 @@ def test_execute_answer():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()

View File

@ -1,4 +1,7 @@
import pytest
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
# Ensures that all node classes are imported.
@ -7,6 +10,12 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
_ = NODE_TYPE_CLASSES_MAPPING
class _TestNodeData(BaseNodeData):
"""Test node data for unit tests."""
pass
def _get_all_subclasses(root: type[Node]) -> list[type[Node]]:
subclasses = []
queue = [root]
@ -34,3 +43,79 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
node_type_and_version = (node_type, node_version)
assert node_type_and_version not in type_version_set
type_version_set.add(node_type_and_version)
def test_extract_node_data_type_from_generic_extracts_type():
"""When a class inherits from Node[T], it should extract T."""
class _ConcreteNode(Node[_TestNodeData]):
node_type = NodeType.CODE
@staticmethod
def version() -> str:
return "1"
result = _ConcreteNode._extract_node_data_type_from_generic()
assert result is _TestNodeData
def test_extract_node_data_type_from_generic_returns_none_for_base_node():
"""The base Node class itself should return None (no generic parameter)."""
result = Node._extract_node_data_type_from_generic()
assert result is None
def test_extract_node_data_type_from_generic_raises_for_non_base_node_data():
"""When generic parameter is not a BaseNodeData subtype, should raise TypeError."""
with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"):
class _InvalidNode(Node[str]): # type: ignore[type-arg]
pass
def test_extract_node_data_type_from_generic_raises_for_non_type():
"""When generic parameter is not a concrete type, should raise TypeError."""
from typing import TypeVar
T = TypeVar("T")
with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"):
class _InvalidNode(Node[T]): # type: ignore[type-arg]
pass
def test_init_subclass_raises_without_generic_or_explicit_type():
"""A subclass must either use Node[T] or explicitly set _node_data_type."""
with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):
class _InvalidNode(Node):
pass
def test_init_subclass_rejects_explicit_node_data_type_without_generic():
"""Setting _node_data_type explicitly cannot bypass the Node[T] requirement."""
with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):
class _ExplicitNode(Node):
_node_data_type = _TestNodeData
node_type = NodeType.CODE
@staticmethod
def version() -> str:
return "1"
def test_init_subclass_sets_node_data_type_from_generic():
"""Verify that __init_subclass__ sets _node_data_type from the generic parameter."""
class _AutoNode(Node[_TestNodeData]):
node_type = NodeType.CODE
@staticmethod
def version() -> str:
return "1"
assert _AutoNode._node_data_type is _TestNodeData

View File

@ -111,8 +111,6 @@ def llm_node(
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node
@ -498,8 +496,6 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node, mock_file_saver

View File

@ -0,0 +1,74 @@
from collections.abc import Mapping
import pytest
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
class _SampleNodeData(BaseNodeData):
foo: str
class _SampleNode(Node[_SampleNodeData]):
node_type = NodeType.ANSWER
@classmethod
def version(cls) -> str:
return "sample-test"
def _run(self):
raise NotImplementedError
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}),
start_at=0.0,
)
return init_params, runtime_state
def test_node_hydrates_data_during_initialization():
graph_config: dict[str, object] = {}
init_params, runtime_state = _build_context(graph_config)
node = _SampleNode(
id="node-1",
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
assert node.node_data.foo == "bar"
assert node.title == "Sample"
def test_missing_generic_argument_raises_type_error():
graph_config: dict[str, object] = {}
with pytest.raises(TypeError):
class _InvalidNode(Node): # type: ignore[type-abstract]
node_type = NodeType.ANSWER
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
raise NotImplementedError

View File

@ -50,8 +50,6 @@ def document_extractor_node(graph_init_params):
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node

View File

@ -114,9 +114,6 @@ def test_execute_if_else_result_true():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@ -187,9 +184,6 @@ def test_execute_if_else_result_false():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@ -252,9 +246,6 @@ def test_array_file_contains_file_name():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
@ -347,7 +338,6 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
graph_runtime_state=graph_runtime_state,
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
@ -417,7 +407,6 @@ def test_execute_if_else_boolean_false_conditions():
"data": node_data,
},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
@ -487,7 +476,6 @@ def test_execute_if_else_boolean_cases_structure():
graph_runtime_state=graph_runtime_state,
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()

View File

@ -57,8 +57,6 @@ def list_operator_node():
graph_init_params=graph_init_params,
graph_runtime_state=MagicMock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node

View File

@ -73,7 +73,6 @@ def tool_node(monkeypatch) -> "ToolNode":
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config["data"])
return node

View File

@ -101,9 +101,6 @@ def test_overwrite_string_variable():
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
@ -203,9 +200,6 @@ def test_append_variable_to_array():
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
@ -296,9 +290,6 @@ def test_clear_array():
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,

View File

@ -139,11 +139,6 @@ def test_remove_first_from_array():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
# Run the node
result = list(node.run())
@ -228,10 +223,6 @@ def test_remove_last_from_array():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])
@ -313,10 +304,6 @@ def test_remove_first_from_empty_array():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])
@ -398,10 +385,6 @@ def test_remove_last_from_empty_array():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])

View File

@ -47,7 +47,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
),
)
node.init_node_data(node_config["data"])
return node