mirror of https://github.com/langgenius/dify.git

Refactor: centralize node data hydration (#27771)

This commit is contained in:
parent 1b733abe82
commit 13bf6547ee
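At a glance, the change replaces each node's hand-written `init_node_data()` and `_get_*` accessor boilerplate with a generic `Node[NodeDataT]` base class that hydrates `config["data"]` itself. The following is a minimal, self-contained sketch of that pattern, not the actual Dify classes; `MyNodeData` and `MyNode` are illustrative names.

```python
# Minimal sketch of the "centralized hydration" pattern this commit introduces.
# MyNodeData / MyNode are hypothetical; the real base lives in core/workflow/nodes/base/node.py.
from collections.abc import Mapping
from typing import Any, Generic, TypeVar, cast, get_args, get_origin

from pydantic import BaseModel


class BaseNodeData(BaseModel):
    title: str = ""
    desc: str | None = None


NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)


class Node(Generic[NodeDataT]):
    _node_data_type: type[BaseNodeData] = BaseNodeData

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Pull `T` out of `Node[T]` so subclasses don't have to repeat it.
        for base in getattr(cls, "__orig_bases__", ()):
            if get_origin(base) is Node:
                (cls._node_data_type,) = get_args(base)

    def __init__(self, config: Mapping[str, Any]) -> None:
        # Centralized hydration: validate config["data"] once, in the base class.
        raw = config.get("data") or {}
        self._node_data = cast(NodeDataT, self._node_data_type.model_validate(raw))

    @property
    def node_data(self) -> NodeDataT:
        return self._node_data


class MyNodeData(BaseNodeData):
    query: str = ""


class MyNode(Node[MyNodeData]):  # no init_node_data / _get_title boilerplate needed
    pass


node = MyNode({"data": {"title": "demo", "query": "hello"}})
print(node.node_data.query)  # -> "hello"
```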
@@ -152,10 +152,5 @@ class CodeExecutor:
            raise CodeExecutionError(f"Unsupported language {language}")

        runner, preload = template_transformer.transform_caller(code, inputs)

        try:
            response = cls.execute_code(language, preload, runner)
        except CodeExecutionError as e:
            raise e

        return template_transformer.transform_response(response)

@@ -26,7 +26,6 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.enums import (
    ErrorStrategy,
    NodeType,
    SystemVariableKey,
    WorkflowNodeExecutionMetadataKey,
@@ -40,7 +39,6 @@ from core.workflow.node_events import (
    StreamCompletedEvent,
)
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
@@ -66,7 +64,7 @@ if TYPE_CHECKING:
    from core.plugin.entities.request import InvokeCredentials


class AgentNode(Node):
class AgentNode(Node[AgentNodeData]):
    """
    Agent Node
    """
@@ -74,27 +72,6 @@ class AgentNode(Node):
    node_type = NodeType.AGENT
    _node_data: AgentNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = AgentNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -2,42 +2,20 @@ from collections.abc import Mapping, Sequence
from typing import Any

from core.variables import ArrayFileSegment, FileSegment, Segment
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser


class AnswerNode(Node):
class AnswerNode(Node[AnswerNodeData]):
    node_type = NodeType.ANSWER
    execution_type = NodeExecutionType.RESPONSE

    _node_data: AnswerNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = AnswerNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -2,7 +2,7 @@ import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from functools import singledispatchmethod
from typing import Any, ClassVar
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
from uuid import uuid4

from core.app.entities.app_invoke_entities import InvokeFrom
@@ -49,12 +49,121 @@ from models.enums import UserFrom

from .entities import BaseNodeData, RetryConfig

NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)

logger = logging.getLogger(__name__)


class Node:
class Node(Generic[NodeDataT]):
    node_type: ClassVar["NodeType"]
    execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
    _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Automatically extract and validate the node data type from the generic parameter.

        When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method:
        1. Inspects `__orig_bases__` to find the `Node[T]` parameterization
        2. Extracts `T` (e.g., `MyNodeData`) from the generic argument
        3. Validates that `T` is a proper `BaseNodeData` subclass
        4. Stores it in `_node_data_type` for automatic hydration in `__init__`

        This eliminates the need for subclasses to manually implement boilerplate
        accessor methods like `_get_title()`, `_get_error_strategy()`, etc.

        How it works:
        ::

            class CodeNode(Node[CodeNodeData]):
                  │              │
                  │              └─────────────────────────────────┐
                  │                                                 │
                  ▼                                                 ▼
            ┌─────────────────────────────┐     ┌─────────────────────────────────┐
            │ __orig_bases__ = (          │     │ CodeNodeData(BaseNodeData)      │
            │     Node[CodeNodeData],     │     │     title: str                  │
            │ )                           │     │     desc: str | None            │
            └──────────────┬──────────────┘     │     ...                         │
                           │                    └─────────────────────────────────┘
                           ▼                                    ▲
            ┌─────────────────────────────┐                     │
            │ get_origin(base) -> Node    │                     │
            │ get_args(base) -> (         │                     │
            │     CodeNodeData,           │ ────────────────────┘
            │ )                           │
            └──────────────┬──────────────┘
                           │
                           ▼
            ┌─────────────────────────────┐
            │ Validate:                   │
            │   - Is it a type?           │
            │   - Is it a BaseNodeData    │
            │     subclass?               │
            └──────────────┬──────────────┘
                           │
                           ▼
            ┌─────────────────────────────┐
            │ cls._node_data_type =       │
            │     CodeNodeData            │
            └─────────────────────────────┘

        Later, in __init__:
        ::

            config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
                                                                     │
                                                                     ▼
                                                          CodeNodeData instance
                                                          (stored in self._node_data)

        Example:
            class CodeNode(Node[CodeNodeData]):  # CodeNodeData is auto-extracted
                node_type = NodeType.CODE
                # No need to implement _get_title, _get_error_strategy, etc.
        """
        super().__init_subclass__(**kwargs)

        if cls is Node:
            return

        node_data_type = cls._extract_node_data_type_from_generic()

        if node_data_type is None:
            raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype")

        cls._node_data_type = node_data_type

    @classmethod
    def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
        """
        Extract the node data type from the generic parameter `Node[T]`.

        Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`.

        Returns:
            The extracted BaseNodeData subtype, or None if not found.

        Raises:
            TypeError: If the generic argument is invalid (not exactly one argument,
                or not a BaseNodeData subtype).
        """
        # __orig_bases__ contains the original generic bases before type erasure.
        # For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`.
        for base in getattr(cls, "__orig_bases__", ()):  # type: ignore[attr-defined]
            origin = get_origin(base)  # Returns `Node` for `Node[CodeNodeData]`
            if origin is Node:
                args = get_args(base)  # Returns `(CodeNodeData,)` for `Node[CodeNodeData]`
                if len(args) != 1:
                    raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument")

                candidate = args[0]
                if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
                    raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")

                return candidate

        return None

    def __init__(
        self,
@@ -63,6 +172,7 @@ class Node:
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
    ) -> None:
        self._graph_init_params = graph_init_params
        self.id = id
        self.tenant_id = graph_init_params.tenant_id
        self.app_id = graph_init_params.app_id
@@ -83,8 +193,24 @@ class Node:
        self._node_execution_id: str = ""
        self._start_at = naive_utc_now()

    @abstractmethod
    def init_node_data(self, data: Mapping[str, Any]) -> None: ...
        raw_node_data = config.get("data") or {}
        if not isinstance(raw_node_data, Mapping):
            raise ValueError("Node config data must be a mapping.")

        self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)

        self.post_init()

    def post_init(self) -> None:
        """Optional hook for subclasses requiring extra initialization."""
        return

    @property
    def graph_init_params(self) -> "GraphInitParams":
        return self._graph_init_params

    def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
        return cast(NodeDataT, self._node_data_type.model_validate(data))

    @abstractmethod
    def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
@@ -273,38 +399,29 @@ class Node:
    def retry(self) -> bool:
        return False

    # Abstract methods that subclasses must implement to provide access
    # to BaseNodeData properties in a type-safe way

    @abstractmethod
    def _get_error_strategy(self) -> ErrorStrategy | None:
        """Get the error strategy for this node."""
        ...
        return self._node_data.error_strategy

    @abstractmethod
    def _get_retry_config(self) -> RetryConfig:
        """Get the retry configuration for this node."""
        ...
        return self._node_data.retry_config

    @abstractmethod
    def _get_title(self) -> str:
        """Get the node title."""
        ...
        return self._node_data.title

    @abstractmethod
    def _get_description(self) -> str | None:
        """Get the node description."""
        ...
        return self._node_data.desc

    @abstractmethod
    def _get_default_value_dict(self) -> dict[str, Any]:
        """Get the default values dictionary for this node."""
        ...
        return self._node_data.default_value_dict

    @abstractmethod
    def get_base_node_data(self) -> BaseNodeData:
        """Get the BaseNodeData object for this node."""
        ...
        return self._node_data

    # Public interface properties that delegate to abstract methods
    @property
@@ -332,6 +449,11 @@ class Node:
        """Get the default values dictionary for this node."""
        return self._get_default_value_dict()

    @property
    def node_data(self) -> NodeDataT:
        """Typed access to this node's configuration data."""
        return self._node_data

    def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
        match result.status:
            case WorkflowNodeExecutionStatus.FAILED:

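The base-class docstring above leans on `typing.get_origin` / `typing.get_args` applied to `__orig_bases__`. The snippet below is a standalone illustration of that introspection mechanism only; `Container`, `ItemT`, and `IntBox` are made-up names for the example, not part of the Dify code.

```python
# Standalone illustration of the introspection used by __init_subclass__ above.
from typing import Generic, TypeVar, get_args, get_origin

ItemT = TypeVar("ItemT")


class Container(Generic[ItemT]):
    pass


class IntBox(Container[int]):
    pass


# __orig_bases__ keeps the parameterized base before type erasure.
base = IntBox.__orig_bases__[0]       # Container[int]
print(get_origin(base) is Container)  # True
print(get_args(base))                 # (<class 'int'>,)
```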
@@ -9,9 +9,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.variables.segments import ArrayFileSegment
from core.variables.types import SegmentType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData

@@ -22,32 +21,11 @@ from .exc import (
)


class CodeNode(Node):
class CodeNode(Node[CodeNodeData]):
    node_type = NodeType.CODE

    _node_data: CodeNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = CodeNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """

@@ -20,9 +20,8 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
@@ -38,7 +37,7 @@ from .entities import DatasourceNodeData
from .exc import DatasourceNodeError, DatasourceParameterError


class DatasourceNode(Node):
class DatasourceNode(Node[DatasourceNodeData]):
    """
    Datasource Node
    """
@@ -47,27 +46,6 @@ class DatasourceNode(Node):
    node_type = NodeType.DATASOURCE
    execution_type = NodeExecutionType.ROOT

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = DatasourceNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def _run(self) -> Generator:
        """
        Run the datasource node

@@ -25,9 +25,8 @@ from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node

from .entities import DocumentExtractorNodeData
@@ -36,7 +35,7 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)


class DocumentExtractorNode(Node):
class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
    """
    Extracts text content from various file types.
    Supports plain text, PDF, and DOC/DOCX files.
@@ -46,27 +45,6 @@ class DocumentExtractorNode(Node):

    _node_data: DocumentExtractorNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = DocumentExtractorNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -1,41 +1,16 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.entities import EndNodeData


class EndNode(Node):
class EndNode(Node[EndNodeData]):
    node_type = NodeType.END
    execution_type = NodeExecutionType.RESPONSE

    _node_data: EndNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = EndNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -7,10 +7,10 @@ from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from factories import file_factory
@@ -31,32 +31,11 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)


class HttpRequestNode(Node):
class HttpRequestNode(Node[HttpRequestNodeData]):
    node_type = NodeType.HTTP_REQUEST

    _node_data: HttpRequestNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = HttpRequestNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        return {

@@ -2,15 +2,14 @@ from collections.abc import Mapping
from typing import Any

from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node

from .entities import HumanInputNodeData


class HumanInputNode(Node):
class HumanInputNode(Node[HumanInputNodeData]):
    node_type = NodeType.HUMAN_INPUT
    execution_type = NodeExecutionType.BRANCH

@@ -28,31 +27,10 @@ class HumanInputNode(Node):

    _node_data: HumanInputNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = HumanInputNodeData(**data)

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def _run(self):  # type: ignore[override]
        if self._is_completion_ready():
            branch_handle = self._resolve_branch_selection()

@@ -3,9 +3,8 @@ from typing import Any, Literal

from typing_extensions import deprecated

from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
@@ -13,33 +12,12 @@ from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor


class IfElseNode(Node):
class IfElseNode(Node[IfElseNodeData]):
    node_type = NodeType.IF_ELSE
    execution_type = NodeExecutionType.BRANCH

    _node_data: IfElseNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = IfElseNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -14,7 +14,6 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
    ErrorStrategy,
    NodeExecutionType,
    NodeType,
    WorkflowNodeExecutionMetadataKey,
@@ -36,7 +35,6 @@ from core.workflow.node_events import (
    StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
@@ -60,7 +58,7 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)


class IterationNode(LLMUsageTrackingMixin, Node):
class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
    """
    Iteration Node.
    """
@@ -69,27 +67,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
    execution_type = NodeExecutionType.CONTAINER
    _node_data: IterationNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = IterationNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        return {

@@ -1,14 +1,10 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import IterationStartNodeData


class IterationStartNode(Node):
class IterationStartNode(Node[IterationStartNodeData]):
    """
    Iteration Start Node.
    """
@@ -17,27 +13,6 @@ class IterationStartNode(Node):

    _node_data: IterationStartNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = IterationStartNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -10,9 +10,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
@@ -35,32 +34,11 @@ default_retrieval_model = {
}


class KnowledgeIndexNode(Node):
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
    _node_data: KnowledgeIndexNodeData
    node_type = NodeType.KNOWLEDGE_INDEX
    execution_type = NodeExecutionType.RESPONSE

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = KnowledgeIndexNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def _run(self) -> NodeRunResult:  # type: ignore
        node_data = self._node_data
        variable_pool = self.graph_runtime_state.variable_pool

@@ -30,14 +30,12 @@ from core.variables import (
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
    ErrorStrategy,
    NodeType,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
    METADATA_FILTER_ASSISTANT_PROMPT_1,
@@ -82,7 +80,7 @@ default_retrieval_model = {
}


class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
    node_type = NodeType.KNOWLEDGE_RETRIEVAL

    _node_data: KnowledgeRetrievalNodeData
@@ -118,27 +116,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
        )
        self._llm_file_saver = llm_file_saver

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = KnowledgeRetrievalNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls):
        return "1"

@@ -1,12 +1,11 @@
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, TypeVar

from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node

from .entities import FilterOperator, ListOperatorNodeData, Order
@@ -35,32 +34,11 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
    return wrapper


class ListOperatorNode(Node):
class ListOperatorNode(Node[ListOperatorNodeData]):
    node_type = NodeType.LIST_OPERATOR

    _node_data: ListOperatorNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = ListOperatorNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -55,7 +55,6 @@ from core.variables import (
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
    ErrorStrategy,
    NodeType,
    SystemVariableKey,
    WorkflowNodeExecutionMetadataKey,
@@ -69,7 +68,7 @@ from core.workflow.node_events import (
    StreamChunkEvent,
    StreamCompletedEvent,
)
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
@@ -100,7 +99,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


class LLMNode(Node):
class LLMNode(Node[LLMNodeData]):
    node_type = NodeType.LLM

    _node_data: LLMNodeData
@@ -139,27 +138,6 @@ class LLMNode(Node):
        )
        self._llm_file_saver = llm_file_saver

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LLMNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -1,14 +1,10 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopEndNodeData


class LoopEndNode(Node):
class LoopEndNode(Node[LoopEndNodeData]):
    """
    Loop End Node.
    """
@@ -17,27 +13,6 @@ class LoopEndNode(Node):

    _node_data: LoopEndNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LoopEndNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import Segment, SegmentType
from core.workflow.enums import (
    ErrorStrategy,
    NodeExecutionType,
    NodeType,
    WorkflowNodeExecutionMetadataKey,
@@ -29,7 +28,6 @@ from core.workflow.node_events import (
    StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
from core.workflow.utils.condition.processor import ConditionProcessor
@@ -42,7 +40,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


class LoopNode(LLMUsageTrackingMixin, Node):
class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
    """
    Loop Node.
    """
@@ -51,27 +49,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
    _node_data: LoopNodeData
    execution_type = NodeExecutionType.CONTAINER

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LoopNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -1,14 +1,10 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.loop.entities import LoopStartNodeData


class LoopStartNode(Node):
class LoopStartNode(Node[LoopStartNodeData]):
    """
    Loop Start Node.
    """
@@ -17,27 +13,6 @@ class LoopStartNode(Node):

    _node_data: LoopStartNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LoopStartNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -69,17 +69,9 @@ class DifyNodeFactory(NodeFactory):
            raise ValueError(f"No latest version class found for node type: {node_type}")

        # Create node instance
        node_instance = node_class(
        return node_class(
            id=node_id,
            config=node_config,
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )

        # Initialize node with provided data
        node_data = node_config.get("data", {})
        if not is_str_dict(node_data):
            raise ValueError(f"Node {node_id} missing data information")
        node_instance.init_node_data(node_data)

        return node_instance

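With hydration moved into `Node.__init__`, the factory hunk above can simply return the constructed instance, and a malformed `data` payload now fails at construction time instead of in a later `init_node_data()` call. The snippet below is a self-contained sketch of that failure mode only; `DemoNodeData` and `DemoNode` are hypothetical stand-ins, not the factory or node classes from the diff.

```python
# Hypothetical sketch: validation errors surface when the node is constructed.
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, ValidationError


class DemoNodeData(BaseModel):
    title: str


class DemoNode:
    def __init__(self, config: Mapping[str, Any]) -> None:
        raw = config.get("data") or {}
        if not isinstance(raw, Mapping):
            raise ValueError("Node config data must be a mapping.")
        # Pydantic validation happens here, at construction time.
        self._node_data = DemoNodeData.model_validate(raw)


try:
    DemoNode({"data": {"title": 123}})  # wrong type for `title`
except ValidationError:
    print("bad node data is rejected at construction time")
```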
@@ -27,10 +27,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
@@ -84,7 +83,7 @@ def extract_json(text):
    return None


class ParameterExtractorNode(Node):
class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
    """
    Parameter Extractor Node.
    """
@@ -93,27 +92,6 @@ class ParameterExtractorNode(Node):

    _node_data: ParameterExtractorNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = ParameterExtractorNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    _model_instance: ModelInstance | None = None
    _model_config: ModelConfigWithCredentialsEntity | None = None

@@ -13,14 +13,13 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
    ErrorStrategy,
    NodeExecutionType,
    NodeType,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
@@ -44,7 +43,7 @@ if TYPE_CHECKING:
    from core.workflow.runtime import GraphRuntimeState


class QuestionClassifierNode(Node):
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
    node_type = NodeType.QUESTION_CLASSIFIER
    execution_type = NodeExecutionType.BRANCH

@@ -78,27 +77,6 @@ class QuestionClassifierNode(Node):
        )
        self._llm_file_saver = llm_file_saver

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = QuestionClassifierNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls):
        return "1"

@@ -1,41 +1,16 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData


class StartNode(Node):
class StartNode(Node[StartNodeData]):
    node_type = NodeType.START
    execution_type = NodeExecutionType.ROOT

    _node_data: StartNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = StartNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -3,41 +3,19 @@ from typing import Any

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData

MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH


class TemplateTransformNode(Node):
class TemplateTransformNode(Node[TemplateTransformNodeData]):
    node_type = NodeType.TEMPLATE_TRANSFORM

    _node_data: TemplateTransformNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = TemplateTransformNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        """

@@ -16,14 +16,12 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
    ErrorStrategy,
    NodeType,
    SystemVariableKey,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
@@ -42,7 +40,7 @@ if TYPE_CHECKING:
    from core.workflow.runtime import VariablePool


class ToolNode(Node):
class ToolNode(Node[ToolNodeData]):
    """
    Tool Node
    """
@@ -51,9 +49,6 @@ class ToolNode(Node):

    _node_data: ToolNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = ToolNodeData.model_validate(data)

    @classmethod
    def version(cls) -> str:
        return "1"
@@ -498,24 +493,6 @@ class ToolNode(Node):

        return result

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @property
    def retry(self) -> bool:
        return self._node_data.retry_config.retry_enabled

@@ -1,43 +1,18 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node

from .entities import TriggerEventNodeData


class TriggerEventNode(Node):
class TriggerEventNode(Node[TriggerEventNodeData]):
    node_type = NodeType.TRIGGER_PLUGIN
    execution_type = NodeExecutionType.ROOT

    _node_data: TriggerEventNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = TriggerEventNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        return {

@@ -1,42 +1,17 @@
from collections.abc import Mapping
from typing import Any

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData


class TriggerScheduleNode(Node):
class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
    node_type = NodeType.TRIGGER_SCHEDULE
    execution_type = NodeExecutionType.ROOT

    _node_data: TriggerScheduleNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = TriggerScheduleNodeData(**data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -3,41 +3,17 @@ from typing import Any

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.enums import NodeExecutionType, NodeType
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node

from .entities import ContentType, WebhookData


class TriggerWebhookNode(Node):
class TriggerWebhookNode(Node[WebhookData]):
    node_type = NodeType.TRIGGER_WEBHOOK
    execution_type = NodeExecutionType.ROOT

    _node_data: WebhookData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = WebhookData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        return {

@@ -1,40 +1,17 @@
from collections.abc import Mapping
from typing import Any

from core.variables.segments import Segment
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData


class VariableAggregatorNode(Node):
class VariableAggregatorNode(Node[VariableAssignerNodeData]):
    node_type = NodeType.VARIABLE_AGGREGATOR

    _node_data: VariableAssignerNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = VariableAssignerNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -5,9 +5,8 @@ from core.variables import SegmentType, Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError

@@ -22,33 +21,12 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]


class VariableAssignerNode(Node):
class VariableAssignerNode(Node[VariableAssignerData]):
    node_type = NodeType.VARIABLE_ASSIGNER
    _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY

    _node_data: VariableAssignerData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = VariableAssignerData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def __init__(
        self,
        id: str,

@@ -7,9 +7,8 @@ from core.variables import SegmentType, Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError

@@ -51,32 +50,11 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
        mapping[key] = selector


class VariableAssignerNode(Node):
class VariableAssignerNode(Node[VariableAssignerNodeData]):
    node_type = NodeType.VARIABLE_ASSIGNER

    _node_data: VariableAssignerNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = VariableAssignerNodeData.model_validate(data)

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
        """
        Check if this Variable Assigner node blocks the output of specific variables.

@@ -159,7 +159,6 @@ class WorkflowEntry:
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        node.init_node_data(node_config_data)

        try:
            # variable selector to variable mapping

@@ -303,7 +302,6 @@ class WorkflowEntry:
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        node.init_node_data(node_data)

        try:
            # variable selector to variable mapping

@@ -69,10 +69,6 @@ def init_code_node(code_config: dict):
        graph_runtime_state=graph_runtime_state,
    )

    # Initialize node data
    if "data" in code_config:
        node.init_node_data(code_config["data"])

    return node

@@ -65,10 +65,6 @@ def init_http_node(config: dict):
        graph_runtime_state=graph_runtime_state,
    )

    # Initialize node data
    if "data" in config:
        node.init_node_data(config["data"])

    return node

@@ -709,10 +705,6 @@ def test_nested_object_variable_selector(setup_http_mock):
        graph_runtime_state=graph_runtime_state,
    )

    # Initialize node data
    if "data" in graph_config["nodes"][1]:
        node.init_node_data(graph_config["nodes"][1]["data"])

    result = node._run()
    assert result.process_data is not None
    data = result.process_data.get("request", "")

@@ -82,10 +82,6 @@ def init_llm_node(config: dict) -> LLMNode:
        graph_runtime_state=graph_runtime_state,
    )

    # Initialize node data
    if "data" in config:
        node.init_node_data(config["data"])

    return node

@@ -85,7 +85,6 @@ def init_parameter_extractor_node(config: dict):
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config.get("data", {}))
    return node

@@ -82,7 +82,6 @@ def test_execute_code(setup_code_executor_mock):
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config.get("data", {}))

    # execute node
    result = node._run()

@@ -62,7 +62,6 @@ def init_tool_node(config: dict):
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config.get("data", {}))
    return node

@@ -3,7 +3,6 @@ from __future__ import annotations

import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any

import pytest

@@ -12,14 +11,19 @@ from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom


class _TestNode(Node):
class _TestNodeData(BaseNodeData):
    type: NodeType | str | None = None
    execution_type: NodeExecutionType | str | None = None


class _TestNode(Node[_TestNodeData]):
    node_type = NodeType.ANSWER
    execution_type = NodeExecutionType.EXECUTABLE

@@ -41,31 +45,8 @@ class _TestNode(Node):
            graph_init_params=graph_init_params,
            graph_runtime_state=graph_runtime_state,
        )
        data = config.get("data", {})
        if isinstance(data, Mapping):
            execution_type = data.get("execution_type")
            if isinstance(execution_type, str):
                self.execution_type = NodeExecutionType(execution_type)
        self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
        self.data: dict[str, object] = {}

    def init_node_data(self, data: Mapping[str, object]) -> None:
        title = str(data.get("title", self.id))
        desc = data.get("description")
        error_strategy_value = data.get("error_strategy")
        error_strategy: ErrorStrategy | None = None
        if isinstance(error_strategy_value, ErrorStrategy):
            error_strategy = error_strategy_value
        elif isinstance(error_strategy_value, str):
            error_strategy = ErrorStrategy(error_strategy_value)
        self._base_node_data = BaseNodeData(
            title=title,
            desc=str(desc) if desc is not None else None,
            error_strategy=error_strategy,
        )
        self.data = dict(data)

        node_type_value = data.get("type")
        node_type_value = self.data.get("type")
        if isinstance(node_type_value, NodeType):
            self.node_type = node_type_value
        elif isinstance(node_type_value, str):

@@ -77,23 +58,19 @@ class _TestNode(Node):
    def _run(self):
        raise NotImplementedError

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._base_node_data.error_strategy
    def post_init(self) -> None:
        super().post_init()
        self._maybe_override_execution_type()
        self.data = dict(self.node_data.model_dump())

    def _get_retry_config(self) -> RetryConfig:
        return self._base_node_data.retry_config

    def _get_title(self) -> str:
        return self._base_node_data.title

    def _get_description(self) -> str | None:
        return self._base_node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._base_node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._base_node_data
    def _maybe_override_execution_type(self) -> None:
        execution_type_value = self.node_data.execution_type
        if execution_type_value is None:
            return
        if isinstance(execution_type_value, NodeExecutionType):
            self.execution_type = execution_type_value
        else:
            self.execution_type = NodeExecutionType(execution_type_value)


@dataclass(slots=True)

@@ -109,7 +86,6 @@ class _SimpleNodeFactory:
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )
        node.init_node_data(node_config.get("data", {}))
        return node

@@ -32,7 +32,7 @@ def test_abort_command():
    # Create mock nodes with required attributes - using shared runtime state
    start_node = StartNode(
        id="start",
        config={"id": "start"},
        config={"id": "start", "data": {"title": "start", "variables": []}},
        graph_init_params=GraphInitParams(
            tenant_id="test_tenant",
            app_id="test_app",

@@ -45,7 +45,6 @@ def test_abort_command():
        ),
        graph_runtime_state=shared_runtime_state,
    )
    start_node.init_node_data({"title": "start", "variables": []})
    mock_graph.nodes["start"] = start_node

    # Mock graph methods

@@ -142,7 +141,7 @@ def test_pause_command():

    start_node = StartNode(
        id="start",
        config={"id": "start"},
        config={"id": "start", "data": {"title": "start", "variables": []}},
        graph_init_params=GraphInitParams(
            tenant_id="test_tenant",
            app_id="test_app",

@@ -155,7 +154,6 @@ def test_pause_command():
        ),
        graph_runtime_state=shared_runtime_state,
    )
    start_node.init_node_data({"title": "start", "variables": []})
    mock_graph.nodes["start"] = start_node

    mock_graph.get_outgoing_edges = MagicMock(return_value=[])

@@ -63,7 +63,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        llm_data = LLMNodeData(

@@ -88,7 +87,6 @@
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")

@@ -105,7 +103,6 @@
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    human_node.init_node_data(human_config["data"])

    llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
    llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")

@@ -125,7 +122,6 @@
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_primary.init_node_data(end_primary_config["data"])

    end_secondary_data = EndNodeData(
        title="End Secondary",

@@ -142,7 +138,6 @@
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_secondary.init_node_data(end_secondary_config["data"])

    graph = (
        Graph.new()

@@ -62,7 +62,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        llm_data = LLMNodeData(

@@ -87,7 +86,6 @@
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")

@@ -104,7 +102,6 @@
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    human_node.init_node_data(human_config["data"])

    llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")

@@ -123,7 +120,6 @@
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_node.init_node_data(end_config["data"])

    graph = (
        Graph.new()

@@ -62,7 +62,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    start_node.init_node_data(start_config["data"])

    def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
        llm_data = LLMNodeData(

@@ -87,7 +86,6 @@
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        llm_node.init_node_data(llm_config["data"])
        return llm_node

    llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")

@@ -118,7 +116,6 @@
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    if_else_node.init_node_data(if_else_config["data"])

    llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
    llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")

@@ -138,7 +135,6 @@
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_primary.init_node_data(end_primary_config["data"])

    end_secondary_data = EndNodeData(
        title="End Secondary",

@@ -155,7 +151,6 @@
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    end_secondary.init_node_data(end_secondary_config["data"])

    graph = (
        Graph.new()

@@ -111,9 +111,6 @@ class MockNodeFactory(DifyNodeFactory):
                mock_config=self.mock_config,
            )

            # Initialize node with provided data
            mock_instance.init_node_data(node_data)

            return mock_instance

        # For non-mocked node types, use parent implementation

@@ -142,6 +142,8 @@ def test_mock_loop_node_preserves_config():
            "start_node_id": "node1",
            "loop_variables": [],
            "outputs": {},
            "break_conditions": [],
            "logical_operator": "and",
        },
    }

@@ -63,7 +63,6 @@ class TestMockTemplateTransformNode:
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        mock_node.init_node_data(node_config["data"])

        # Run the node
        result = mock_node._run()

@@ -125,7 +124,6 @@ class TestMockTemplateTransformNode:
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        mock_node.init_node_data(node_config["data"])

        # Run the node
        result = mock_node._run()

@@ -184,7 +182,6 @@
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        mock_node.init_node_data(node_config["data"])

        # Run the node
        result = mock_node._run()

@@ -246,7 +243,6 @@
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        mock_node.init_node_data(node_config["data"])

        # Run the node
        result = mock_node._run()

@@ -311,7 +307,6 @@ class TestMockCodeNode:
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        mock_node.init_node_data(node_config["data"])

        # Run the node
        result = mock_node._run()

@@ -376,7 +371,6 @@
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        mock_node.init_node_data(node_config["data"])

        # Run the node
        result = mock_node._run()

@@ -445,7 +439,6 @@
            graph_runtime_state=graph_runtime_state,
            mock_config=mock_config,
        )
        mock_node.init_node_data(node_config["data"])

        # Run the node
        result = mock_node._run()

@@ -83,9 +83,6 @@ def test_execute_answer():
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Mock db.session.close()
    db.session.close = MagicMock()

@@ -1,4 +1,7 @@
import pytest

from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node

# Ensures that all node classes are imported.

@@ -7,6 +10,12 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
_ = NODE_TYPE_CLASSES_MAPPING


class _TestNodeData(BaseNodeData):
    """Test node data for unit tests."""

    pass


def _get_all_subclasses(root: type[Node]) -> list[type[Node]]:
    subclasses = []
    queue = [root]

@@ -34,3 +43,79 @@
        node_type_and_version = (node_type, node_version)
        assert node_type_and_version not in type_version_set
        type_version_set.add(node_type_and_version)


def test_extract_node_data_type_from_generic_extracts_type():
    """When a class inherits from Node[T], it should extract T."""

    class _ConcreteNode(Node[_TestNodeData]):
        node_type = NodeType.CODE

        @staticmethod
        def version() -> str:
            return "1"

    result = _ConcreteNode._extract_node_data_type_from_generic()

    assert result is _TestNodeData


def test_extract_node_data_type_from_generic_returns_none_for_base_node():
    """The base Node class itself should return None (no generic parameter)."""
    result = Node._extract_node_data_type_from_generic()

    assert result is None


def test_extract_node_data_type_from_generic_raises_for_non_base_node_data():
    """When generic parameter is not a BaseNodeData subtype, should raise TypeError."""
    with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"):

        class _InvalidNode(Node[str]):  # type: ignore[type-arg]
            pass


def test_extract_node_data_type_from_generic_raises_for_non_type():
    """When generic parameter is not a concrete type, should raise TypeError."""
    from typing import TypeVar

    T = TypeVar("T")

    with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"):

        class _InvalidNode(Node[T]):  # type: ignore[type-arg]
            pass


def test_init_subclass_raises_without_generic_or_explicit_type():
    """A subclass must either use Node[T] or explicitly set _node_data_type."""
    with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):

        class _InvalidNode(Node):
            pass


def test_init_subclass_rejects_explicit_node_data_type_without_generic():
    """Setting _node_data_type explicitly cannot bypass the Node[T] requirement."""
    with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):

        class _ExplicitNode(Node):
            _node_data_type = _TestNodeData
            node_type = NodeType.CODE

            @staticmethod
            def version() -> str:
                return "1"


def test_init_subclass_sets_node_data_type_from_generic():
    """Verify that __init_subclass__ sets _node_data_type from the generic parameter."""

    class _AutoNode(Node[_TestNodeData]):
        node_type = NodeType.CODE

        @staticmethod
        def version() -> str:
            return "1"

    assert _AutoNode._node_data_type is _TestNodeData

@@ -111,8 +111,6 @@ def llm_node(
        graph_runtime_state=graph_runtime_state,
        llm_file_saver=mock_file_saver,
    )
    # Initialize node data
    node.init_node_data(node_config["data"])
    return node

@@ -498,8 +496,6 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
        graph_runtime_state=graph_runtime_state,
        llm_file_saver=mock_file_saver,
    )
    # Initialize node data
    node.init_node_data(node_config["data"])
    return node, mock_file_saver

@@ -0,0 +1,74 @@
from collections.abc import Mapping

import pytest

from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable


class _SampleNodeData(BaseNodeData):
    foo: str


class _SampleNode(Node[_SampleNodeData]):
    node_type = NodeType.ANSWER

    @classmethod
    def version(cls) -> str:
        return "sample-test"

    def _run(self):
        raise NotImplementedError


def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
    init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}),
        start_at=0.0,
    )
    return init_params, runtime_state


def test_node_hydrates_data_during_initialization():
    graph_config: dict[str, object] = {}
    init_params, runtime_state = _build_context(graph_config)

    node = _SampleNode(
        id="node-1",
        config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )

    assert node.node_data.foo == "bar"
    assert node.title == "Sample"


def test_missing_generic_argument_raises_type_error():
    graph_config: dict[str, object] = {}

    with pytest.raises(TypeError):

        class _InvalidNode(Node):  # type: ignore[type-abstract]
            node_type = NodeType.ANSWER

            @classmethod
            def version(cls) -> str:
                return "1"

            def _run(self):
                raise NotImplementedError

@@ -50,8 +50,6 @@ def document_extractor_node(graph_init_params):
        graph_init_params=graph_init_params,
        graph_runtime_state=Mock(),
    )
    # Initialize node data
    node.init_node_data(node_config["data"])
    return node

@@ -114,9 +114,6 @@ def test_execute_if_else_result_true():
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Mock db.session.close()
    db.session.close = MagicMock()

@@ -187,9 +184,6 @@ def test_execute_if_else_result_false():
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Mock db.session.close()
    db.session.close = MagicMock()

@@ -252,9 +246,6 @@ def test_array_file_contains_file_name():
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
        value=[
            File(

@@ -347,7 +338,6 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
        graph_runtime_state=graph_runtime_state,
        config={"id": "if-else", "data": node_data},
    )
    node.init_node_data(node_data)

    # Mock db.session.close()
    db.session.close = MagicMock()

@@ -417,7 +407,6 @@ def test_execute_if_else_boolean_false_conditions():
            "data": node_data,
        },
    )
    node.init_node_data(node_data)

    # Mock db.session.close()
    db.session.close = MagicMock()

@@ -487,7 +476,6 @@ def test_execute_if_else_boolean_cases_structure():
        graph_runtime_state=graph_runtime_state,
        config={"id": "if-else", "data": node_data},
    )
    node.init_node_data(node_data)

    # Mock db.session.close()
    db.session.close = MagicMock()

@@ -57,8 +57,6 @@ def list_operator_node():
        graph_init_params=graph_init_params,
        graph_runtime_state=MagicMock(),
    )
    # Initialize node data
    node.init_node_data(node_config["data"])
    node.graph_runtime_state = MagicMock()
    node.graph_runtime_state.variable_pool = MagicMock()
    return node

@@ -73,7 +73,6 @@ def tool_node(monkeypatch) -> "ToolNode":
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config["data"])
    return node

@@ -101,9 +101,6 @@ def test_overwrite_string_variable():
        conv_var_updater_factory=mock_conv_var_updater_factory,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    list(node.run())
    expected_var = StringVariable(
        id=conversation_variable.id,

@@ -203,9 +200,6 @@ def test_append_variable_to_array():
        conv_var_updater_factory=mock_conv_var_updater_factory,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    list(node.run())
    expected_value = list(conversation_variable.value)
    expected_value.append(input_variable.value)

@@ -296,9 +290,6 @@ def test_clear_array():
        conv_var_updater_factory=mock_conv_var_updater_factory,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    list(node.run())
    expected_var = ArrayStringVariable(
        id=conversation_variable.id,

@@ -139,11 +139,6 @@ def test_remove_first_from_array():
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Skip the mock assertion since we're in a test environment

    # Run the node
    result = list(node.run())

@@ -228,10 +223,6 @@ def test_remove_last_from_array():
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Skip the mock assertion since we're in a test environment
    list(node.run())

    got = variable_pool.get(["conversation", conversation_variable.name])

@@ -313,10 +304,6 @@ def test_remove_first_from_empty_array():
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Skip the mock assertion since we're in a test environment
    list(node.run())

    got = variable_pool.get(["conversation", conversation_variable.name])

@@ -398,10 +385,6 @@ def test_remove_last_from_empty_array():
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Skip the mock assertion since we're in a test environment
    list(node.run())

    got = variable_pool.get(["conversation", conversation_variable.name])

@@ -47,7 +47,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
        ),
    )

    node.init_node_data(node_config["data"])
    return node