mirror of https://github.com/langgenius/dify.git
Refactor: centralize node data hydration (#27771)
This commit is contained in:
parent
1b733abe82
commit
13bf6547ee
|
|
@ -152,10 +152,5 @@ class CodeExecutor:
|
||||||
raise CodeExecutionError(f"Unsupported language {language}")
|
raise CodeExecutionError(f"Unsupported language {language}")
|
||||||
|
|
||||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||||
|
response = cls.execute_code(language, preload, runner)
|
||||||
try:
|
|
||||||
response = cls.execute_code(language, preload, runner)
|
|
||||||
except CodeExecutionError as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return template_transformer.transform_response(response)
|
return template_transformer.transform_response(response)
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,6 @@ from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from core.variables.segments import ArrayFileSegment, StringSegment
|
from core.variables.segments import ArrayFileSegment, StringSegment
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
|
||||||
NodeType,
|
NodeType,
|
||||||
SystemVariableKey,
|
SystemVariableKey,
|
||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
|
|
@ -40,7 +39,6 @@ from core.workflow.node_events import (
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
|
@ -66,7 +64,7 @@ if TYPE_CHECKING:
|
||||||
from core.plugin.entities.request import InvokeCredentials
|
from core.plugin.entities.request import InvokeCredentials
|
||||||
|
|
||||||
|
|
||||||
class AgentNode(Node):
|
class AgentNode(Node[AgentNodeData]):
|
||||||
"""
|
"""
|
||||||
Agent Node
|
Agent Node
|
||||||
"""
|
"""
|
||||||
|
|
@ -74,27 +72,6 @@ class AgentNode(Node):
|
||||||
node_type = NodeType.AGENT
|
node_type = NodeType.AGENT
|
||||||
_node_data: AgentNodeData
|
_node_data: AgentNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = AgentNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -2,42 +2,20 @@ from collections.abc import Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.variables import ArrayFileSegment, FileSegment, Segment
|
from core.variables import ArrayFileSegment, FileSegment, Segment
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.answer.entities import AnswerNodeData
|
from core.workflow.nodes.answer.entities import AnswerNodeData
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.template import Template
|
from core.workflow.nodes.base.template import Template
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
|
|
||||||
|
|
||||||
class AnswerNode(Node):
|
class AnswerNode(Node[AnswerNodeData]):
|
||||||
node_type = NodeType.ANSWER
|
node_type = NodeType.ANSWER
|
||||||
execution_type = NodeExecutionType.RESPONSE
|
execution_type = NodeExecutionType.RESPONSE
|
||||||
|
|
||||||
_node_data: AnswerNodeData
|
_node_data: AnswerNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = AnswerNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from functools import singledispatchmethod
|
from functools import singledispatchmethod
|
||||||
from typing import Any, ClassVar
|
from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
|
@ -49,12 +49,121 @@ from models.enums import UserFrom
|
||||||
|
|
||||||
from .entities import BaseNodeData, RetryConfig
|
from .entities import BaseNodeData, RetryConfig
|
||||||
|
|
||||||
|
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
class Node(Generic[NodeDataT]):
|
||||||
node_type: ClassVar["NodeType"]
|
node_type: ClassVar["NodeType"]
|
||||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||||
|
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||||
|
"""
|
||||||
|
Automatically extract and validate the node data type from the generic parameter.
|
||||||
|
|
||||||
|
When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method:
|
||||||
|
1. Inspects `__orig_bases__` to find the `Node[T]` parameterization
|
||||||
|
2. Extracts `T` (e.g., `MyNodeData`) from the generic argument
|
||||||
|
3. Validates that `T` is a proper `BaseNodeData` subclass
|
||||||
|
4. Stores it in `_node_data_type` for automatic hydration in `__init__`
|
||||||
|
|
||||||
|
This eliminates the need for subclasses to manually implement boilerplate
|
||||||
|
accessor methods like `_get_title()`, `_get_error_strategy()`, etc.
|
||||||
|
|
||||||
|
How it works:
|
||||||
|
::
|
||||||
|
|
||||||
|
class CodeNode(Node[CodeNodeData]):
|
||||||
|
│ │
|
||||||
|
│ └─────────────────────────────────┐
|
||||||
|
│ │
|
||||||
|
▼ ▼
|
||||||
|
┌─────────────────────────────┐ ┌─────────────────────────────────┐
|
||||||
|
│ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │
|
||||||
|
│ Node[CodeNodeData], │ │ title: str │
|
||||||
|
│ ) │ │ desc: str | None │
|
||||||
|
└──────────────┬──────────────┘ │ ... │
|
||||||
|
│ └─────────────────────────────────┘
|
||||||
|
▼ ▲
|
||||||
|
┌─────────────────────────────┐ │
|
||||||
|
│ get_origin(base) -> Node │ │
|
||||||
|
│ get_args(base) -> ( │ │
|
||||||
|
│ CodeNodeData, │ ──────────────────────┘
|
||||||
|
│ ) │
|
||||||
|
└──────────────┬──────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────┐
|
||||||
|
│ Validate: │
|
||||||
|
│ - Is it a type? │
|
||||||
|
│ - Is it a BaseNodeData │
|
||||||
|
│ subclass? │
|
||||||
|
└──────────────┬──────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────┐
|
||||||
|
│ cls._node_data_type = │
|
||||||
|
│ CodeNodeData │
|
||||||
|
└─────────────────────────────┘
|
||||||
|
|
||||||
|
Later, in __init__:
|
||||||
|
::
|
||||||
|
|
||||||
|
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
CodeNodeData instance
|
||||||
|
(stored in self._node_data)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||||
|
node_type = NodeType.CODE
|
||||||
|
# No need to implement _get_title, _get_error_strategy, etc.
|
||||||
|
"""
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
|
||||||
|
if cls is Node:
|
||||||
|
return
|
||||||
|
|
||||||
|
node_data_type = cls._extract_node_data_type_from_generic()
|
||||||
|
|
||||||
|
if node_data_type is None:
|
||||||
|
raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype")
|
||||||
|
|
||||||
|
cls._node_data_type = node_data_type
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||||
|
"""
|
||||||
|
Extract the node data type from the generic parameter `Node[T]`.
|
||||||
|
|
||||||
|
Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The extracted BaseNodeData subtype, or None if not found.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the generic argument is invalid (not exactly one argument,
|
||||||
|
or not a BaseNodeData subtype).
|
||||||
|
"""
|
||||||
|
# __orig_bases__ contains the original generic bases before type erasure.
|
||||||
|
# For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`.
|
||||||
|
for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined]
|
||||||
|
origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]`
|
||||||
|
if origin is Node:
|
||||||
|
args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]`
|
||||||
|
if len(args) != 1:
|
||||||
|
raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument")
|
||||||
|
|
||||||
|
candidate = args[0]
|
||||||
|
if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
|
||||||
|
raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")
|
||||||
|
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -63,6 +172,7 @@ class Node:
|
||||||
graph_init_params: "GraphInitParams",
|
graph_init_params: "GraphInitParams",
|
||||||
graph_runtime_state: "GraphRuntimeState",
|
graph_runtime_state: "GraphRuntimeState",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self._graph_init_params = graph_init_params
|
||||||
self.id = id
|
self.id = id
|
||||||
self.tenant_id = graph_init_params.tenant_id
|
self.tenant_id = graph_init_params.tenant_id
|
||||||
self.app_id = graph_init_params.app_id
|
self.app_id = graph_init_params.app_id
|
||||||
|
|
@ -83,8 +193,24 @@ class Node:
|
||||||
self._node_execution_id: str = ""
|
self._node_execution_id: str = ""
|
||||||
self._start_at = naive_utc_now()
|
self._start_at = naive_utc_now()
|
||||||
|
|
||||||
@abstractmethod
|
raw_node_data = config.get("data") or {}
|
||||||
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
|
if not isinstance(raw_node_data, Mapping):
|
||||||
|
raise ValueError("Node config data must be a mapping.")
|
||||||
|
|
||||||
|
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def post_init(self) -> None:
|
||||||
|
"""Optional hook for subclasses requiring extra initialization."""
|
||||||
|
return
|
||||||
|
|
||||||
|
@property
|
||||||
|
def graph_init_params(self) -> "GraphInitParams":
|
||||||
|
return self._graph_init_params
|
||||||
|
|
||||||
|
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||||
|
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||||
|
|
@ -273,38 +399,29 @@ class Node:
|
||||||
def retry(self) -> bool:
|
def retry(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Abstract methods that subclasses must implement to provide access
|
|
||||||
# to BaseNodeData properties in a type-safe way
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||||
"""Get the error strategy for this node."""
|
"""Get the error strategy for this node."""
|
||||||
...
|
return self._node_data.error_strategy
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
def _get_retry_config(self) -> RetryConfig:
|
||||||
"""Get the retry configuration for this node."""
|
"""Get the retry configuration for this node."""
|
||||||
...
|
return self._node_data.retry_config
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _get_title(self) -> str:
|
def _get_title(self) -> str:
|
||||||
"""Get the node title."""
|
"""Get the node title."""
|
||||||
...
|
return self._node_data.title
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _get_description(self) -> str | None:
|
def _get_description(self) -> str | None:
|
||||||
"""Get the node description."""
|
"""Get the node description."""
|
||||||
...
|
return self._node_data.desc
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||||
"""Get the default values dictionary for this node."""
|
"""Get the default values dictionary for this node."""
|
||||||
...
|
return self._node_data.default_value_dict
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
def get_base_node_data(self) -> BaseNodeData:
|
||||||
"""Get the BaseNodeData object for this node."""
|
"""Get the BaseNodeData object for this node."""
|
||||||
...
|
return self._node_data
|
||||||
|
|
||||||
# Public interface properties that delegate to abstract methods
|
# Public interface properties that delegate to abstract methods
|
||||||
@property
|
@property
|
||||||
|
|
@ -332,6 +449,11 @@ class Node:
|
||||||
"""Get the default values dictionary for this node."""
|
"""Get the default values dictionary for this node."""
|
||||||
return self._get_default_value_dict()
|
return self._get_default_value_dict()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node_data(self) -> NodeDataT:
|
||||||
|
"""Typed access to this node's configuration data."""
|
||||||
|
return self._node_data
|
||||||
|
|
||||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||||
match result.status:
|
match result.status:
|
||||||
case WorkflowNodeExecutionStatus.FAILED:
|
case WorkflowNodeExecutionStatus.FAILED:
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,8 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||||
from core.variables.segments import ArrayFileSegment
|
from core.variables.segments import ArrayFileSegment
|
||||||
from core.variables.types import SegmentType
|
from core.variables.types import SegmentType
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.code.entities import CodeNodeData
|
from core.workflow.nodes.code.entities import CodeNodeData
|
||||||
|
|
||||||
|
|
@ -22,32 +21,11 @@ from .exc import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CodeNode(Node):
|
class CodeNode(Node[CodeNodeData]):
|
||||||
node_type = NodeType.CODE
|
node_type = NodeType.CODE
|
||||||
|
|
||||||
_node_data: CodeNodeData
|
_node_data: CodeNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = CodeNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,8 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.variables.segments import ArrayAnySegment
|
from core.variables.segments import ArrayAnySegment
|
||||||
from core.variables.variables import ArrayAnyVariable
|
from core.variables.variables import ArrayAnyVariable
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
from core.workflow.nodes.tool.exc import ToolFileError
|
from core.workflow.nodes.tool.exc import ToolFileError
|
||||||
|
|
@ -38,7 +37,7 @@ from .entities import DatasourceNodeData
|
||||||
from .exc import DatasourceNodeError, DatasourceParameterError
|
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||||
|
|
||||||
|
|
||||||
class DatasourceNode(Node):
|
class DatasourceNode(Node[DatasourceNodeData]):
|
||||||
"""
|
"""
|
||||||
Datasource Node
|
Datasource Node
|
||||||
"""
|
"""
|
||||||
|
|
@ -47,27 +46,6 @@ class DatasourceNode(Node):
|
||||||
node_type = NodeType.DATASOURCE
|
node_type = NodeType.DATASOURCE
|
||||||
execution_type = NodeExecutionType.ROOT
|
execution_type = NodeExecutionType.ROOT
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
|
||||||
self._node_data = DatasourceNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
def _run(self) -> Generator:
|
def _run(self) -> Generator:
|
||||||
"""
|
"""
|
||||||
Run the datasource node
|
Run the datasource node
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,8 @@ from core.file import File, FileTransferMethod, file_manager
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from core.variables import ArrayFileSegment
|
from core.variables import ArrayFileSegment
|
||||||
from core.variables.segments import ArrayStringSegment, FileSegment
|
from core.variables.segments import ArrayStringSegment, FileSegment
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
|
|
||||||
from .entities import DocumentExtractorNodeData
|
from .entities import DocumentExtractorNodeData
|
||||||
|
|
@ -36,7 +35,7 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DocumentExtractorNode(Node):
|
class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||||
"""
|
"""
|
||||||
Extracts text content from various file types.
|
Extracts text content from various file types.
|
||||||
Supports plain text, PDF, and DOC/DOCX files.
|
Supports plain text, PDF, and DOC/DOCX files.
|
||||||
|
|
@ -46,27 +45,6 @@ class DocumentExtractorNode(Node):
|
||||||
|
|
||||||
_node_data: DocumentExtractorNodeData
|
_node_data: DocumentExtractorNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = DocumentExtractorNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,16 @@
|
||||||
from collections.abc import Mapping
|
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.template import Template
|
from core.workflow.nodes.base.template import Template
|
||||||
from core.workflow.nodes.end.entities import EndNodeData
|
from core.workflow.nodes.end.entities import EndNodeData
|
||||||
|
|
||||||
|
|
||||||
class EndNode(Node):
|
class EndNode(Node[EndNodeData]):
|
||||||
node_type = NodeType.END
|
node_type = NodeType.END
|
||||||
execution_type = NodeExecutionType.RESPONSE
|
execution_type = NodeExecutionType.RESPONSE
|
||||||
|
|
||||||
_node_data: EndNodeData
|
_node_data: EndNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = EndNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,10 @@ from configs import dify_config
|
||||||
from core.file import File, FileTransferMethod
|
from core.file import File, FileTransferMethod
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from core.variables.segments import ArrayFileSegment
|
from core.variables.segments import ArrayFileSegment
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base import variable_template_parser
|
from core.workflow.nodes.base import variable_template_parser
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
from core.workflow.nodes.base.entities import VariableSelector
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.http_request.executor import Executor
|
from core.workflow.nodes.http_request.executor import Executor
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
|
|
@ -31,32 +31,11 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestNode(Node):
|
class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||||
node_type = NodeType.HTTP_REQUEST
|
node_type = NodeType.HTTP_REQUEST
|
||||||
|
|
||||||
_node_data: HttpRequestNodeData
|
_node_data: HttpRequestNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = HttpRequestNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -2,15 +2,14 @@ from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
|
|
||||||
from .entities import HumanInputNodeData
|
from .entities import HumanInputNodeData
|
||||||
|
|
||||||
|
|
||||||
class HumanInputNode(Node):
|
class HumanInputNode(Node[HumanInputNodeData]):
|
||||||
node_type = NodeType.HUMAN_INPUT
|
node_type = NodeType.HUMAN_INPUT
|
||||||
execution_type = NodeExecutionType.BRANCH
|
execution_type = NodeExecutionType.BRANCH
|
||||||
|
|
||||||
|
|
@ -28,31 +27,10 @@ class HumanInputNode(Node):
|
||||||
|
|
||||||
_node_data: HumanInputNodeData
|
_node_data: HumanInputNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
|
||||||
self._node_data = HumanInputNodeData(**data)
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def _run(self): # type: ignore[override]
|
def _run(self): # type: ignore[override]
|
||||||
if self._is_completion_ready():
|
if self._is_completion_ready():
|
||||||
branch_handle = self._resolve_branch_selection()
|
branch_handle = self._resolve_branch_selection()
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,8 @@ from typing import Any, Literal
|
||||||
|
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
|
@ -13,33 +12,12 @@ from core.workflow.utils.condition.entities import Condition
|
||||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||||
|
|
||||||
|
|
||||||
class IfElseNode(Node):
|
class IfElseNode(Node[IfElseNodeData]):
|
||||||
node_type = NodeType.IF_ELSE
|
node_type = NodeType.IF_ELSE
|
||||||
execution_type = NodeExecutionType.BRANCH
|
execution_type = NodeExecutionType.BRANCH
|
||||||
|
|
||||||
_node_data: IfElseNodeData
|
_node_data: IfElseNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = IfElseNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||||
from core.variables.variables import VariableUnion
|
from core.variables.variables import VariableUnion
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
|
||||||
NodeExecutionType,
|
NodeExecutionType,
|
||||||
NodeType,
|
NodeType,
|
||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
|
|
@ -36,7 +35,6 @@ from core.workflow.node_events import (
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
|
@ -60,7 +58,7 @@ logger = logging.getLogger(__name__)
|
||||||
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
|
||||||
|
|
||||||
|
|
||||||
class IterationNode(LLMUsageTrackingMixin, Node):
|
class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||||
"""
|
"""
|
||||||
Iteration Node.
|
Iteration Node.
|
||||||
"""
|
"""
|
||||||
|
|
@ -69,27 +67,6 @@ class IterationNode(LLMUsageTrackingMixin, Node):
|
||||||
execution_type = NodeExecutionType.CONTAINER
|
execution_type = NodeExecutionType.CONTAINER
|
||||||
_node_data: IterationNodeData
|
_node_data: IterationNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = IterationNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,10 @@
|
||||||
from collections.abc import Mapping
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.iteration.entities import IterationStartNodeData
|
from core.workflow.nodes.iteration.entities import IterationStartNodeData
|
||||||
|
|
||||||
|
|
||||||
class IterationStartNode(Node):
|
class IterationStartNode(Node[IterationStartNodeData]):
|
||||||
"""
|
"""
|
||||||
Iteration Start Node.
|
Iteration Start Node.
|
||||||
"""
|
"""
|
||||||
|
|
@ -17,27 +13,6 @@ class IterationStartNode(Node):
|
||||||
|
|
||||||
_node_data: IterationStartNodeData
|
_node_data: IterationStartNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = IterationStartNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.template import Template
|
from core.workflow.nodes.base.template import Template
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
|
@ -35,32 +34,11 @@ default_retrieval_model = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeIndexNode(Node):
|
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||||
_node_data: KnowledgeIndexNodeData
|
_node_data: KnowledgeIndexNodeData
|
||||||
node_type = NodeType.KNOWLEDGE_INDEX
|
node_type = NodeType.KNOWLEDGE_INDEX
|
||||||
execution_type = NodeExecutionType.RESPONSE
|
execution_type = NodeExecutionType.RESPONSE
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
|
||||||
self._node_data = KnowledgeIndexNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult: # type: ignore
|
def _run(self) -> NodeRunResult: # type: ignore
|
||||||
node_data = self._node_data
|
node_data = self._node_data
|
||||||
variable_pool = self.graph_runtime_state.variable_pool
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
|
|
||||||
|
|
@ -30,14 +30,12 @@ from core.variables import (
|
||||||
from core.variables.segments import ArrayObjectSegment
|
from core.variables.segments import ArrayObjectSegment
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
|
||||||
NodeType,
|
NodeType,
|
||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
)
|
)
|
||||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
|
||||||
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
METADATA_FILTER_ASSISTANT_PROMPT_1,
|
||||||
|
|
@ -82,7 +80,7 @@ default_retrieval_model = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
|
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||||
|
|
||||||
_node_data: KnowledgeRetrievalNodeData
|
_node_data: KnowledgeRetrievalNodeData
|
||||||
|
|
@ -118,27 +116,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node):
|
||||||
)
|
)
|
||||||
self._llm_file_saver = llm_file_saver
|
self._llm_file_saver = llm_file_saver
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls):
|
def version(cls):
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
from collections.abc import Callable, Mapping, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Any, TypeAlias, TypeVar
|
from typing import Any, TypeAlias, TypeVar
|
||||||
|
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||||
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
|
|
||||||
from .entities import FilterOperator, ListOperatorNodeData, Order
|
from .entities import FilterOperator, ListOperatorNodeData, Order
|
||||||
|
|
@ -35,32 +34,11 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class ListOperatorNode(Node):
|
class ListOperatorNode(Node[ListOperatorNodeData]):
|
||||||
node_type = NodeType.LIST_OPERATOR
|
node_type = NodeType.LIST_OPERATOR
|
||||||
|
|
||||||
_node_data: ListOperatorNodeData
|
_node_data: ListOperatorNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = ListOperatorNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,6 @@ from core.variables import (
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
|
||||||
NodeType,
|
NodeType,
|
||||||
SystemVariableKey,
|
SystemVariableKey,
|
||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
|
|
@ -69,7 +68,7 @@ from core.workflow.node_events import (
|
||||||
StreamChunkEvent,
|
StreamChunkEvent,
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
from core.workflow.nodes.base.entities import VariableSelector
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
|
@ -100,7 +99,7 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMNode(Node):
|
class LLMNode(Node[LLMNodeData]):
|
||||||
node_type = NodeType.LLM
|
node_type = NodeType.LLM
|
||||||
|
|
||||||
_node_data: LLMNodeData
|
_node_data: LLMNodeData
|
||||||
|
|
@ -139,27 +138,6 @@ class LLMNode(Node):
|
||||||
)
|
)
|
||||||
self._llm_file_saver = llm_file_saver
|
self._llm_file_saver = llm_file_saver
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = LLMNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,10 @@
|
||||||
from collections.abc import Mapping
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.loop.entities import LoopEndNodeData
|
from core.workflow.nodes.loop.entities import LoopEndNodeData
|
||||||
|
|
||||||
|
|
||||||
class LoopEndNode(Node):
|
class LoopEndNode(Node[LoopEndNodeData]):
|
||||||
"""
|
"""
|
||||||
Loop End Node.
|
Loop End Node.
|
||||||
"""
|
"""
|
||||||
|
|
@ -17,27 +13,6 @@ class LoopEndNode(Node):
|
||||||
|
|
||||||
_node_data: LoopEndNodeData
|
_node_data: LoopEndNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = LoopEndNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.variables import Segment, SegmentType
|
from core.variables import Segment, SegmentType
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
|
||||||
NodeExecutionType,
|
NodeExecutionType,
|
||||||
NodeType,
|
NodeType,
|
||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
|
|
@ -29,7 +28,6 @@ from core.workflow.node_events import (
|
||||||
StreamCompletedEvent,
|
StreamCompletedEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
from core.workflow.nodes.base import LLMUsageTrackingMixin
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
from core.workflow.nodes.loop.entities import LoopNodeData, LoopVariableData
|
||||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||||
|
|
@ -42,7 +40,7 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoopNode(LLMUsageTrackingMixin, Node):
|
class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||||
"""
|
"""
|
||||||
Loop Node.
|
Loop Node.
|
||||||
"""
|
"""
|
||||||
|
|
@ -51,27 +49,6 @@ class LoopNode(LLMUsageTrackingMixin, Node):
|
||||||
_node_data: LoopNodeData
|
_node_data: LoopNodeData
|
||||||
execution_type = NodeExecutionType.CONTAINER
|
execution_type = NodeExecutionType.CONTAINER
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = LoopNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,10 @@
|
||||||
from collections.abc import Mapping
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.loop.entities import LoopStartNodeData
|
from core.workflow.nodes.loop.entities import LoopStartNodeData
|
||||||
|
|
||||||
|
|
||||||
class LoopStartNode(Node):
|
class LoopStartNode(Node[LoopStartNodeData]):
|
||||||
"""
|
"""
|
||||||
Loop Start Node.
|
Loop Start Node.
|
||||||
"""
|
"""
|
||||||
|
|
@ -17,27 +13,6 @@ class LoopStartNode(Node):
|
||||||
|
|
||||||
_node_data: LoopStartNodeData
|
_node_data: LoopStartNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = LoopStartNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -69,17 +69,9 @@ class DifyNodeFactory(NodeFactory):
|
||||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||||
|
|
||||||
# Create node instance
|
# Create node instance
|
||||||
node_instance = node_class(
|
return node_class(
|
||||||
id=node_id,
|
id=node_id,
|
||||||
config=node_config,
|
config=node_config,
|
||||||
graph_init_params=self.graph_init_params,
|
graph_init_params=self.graph_init_params,
|
||||||
graph_runtime_state=self.graph_runtime_state,
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node with provided data
|
|
||||||
node_data = node_config.get("data", {})
|
|
||||||
if not is_str_dict(node_data):
|
|
||||||
raise ValueError(f"Node {node_id} missing data information")
|
|
||||||
node_instance.init_node_data(node_data)
|
|
||||||
|
|
||||||
return node_instance
|
|
||||||
|
|
|
||||||
|
|
@ -27,10 +27,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
||||||
from core.prompt.simple_prompt_transform import ModelMode
|
from core.prompt.simple_prompt_transform import ModelMode
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.variables.types import ArrayValidation, SegmentType
|
from core.variables.types import ArrayValidation, SegmentType
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base import variable_template_parser
|
from core.workflow.nodes.base import variable_template_parser
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.llm import ModelConfig, llm_utils
|
from core.workflow.nodes.llm import ModelConfig, llm_utils
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
|
@ -84,7 +83,7 @@ def extract_json(text):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ParameterExtractorNode(Node):
|
class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||||
"""
|
"""
|
||||||
Parameter Extractor Node.
|
Parameter Extractor Node.
|
||||||
"""
|
"""
|
||||||
|
|
@ -93,27 +92,6 @@ class ParameterExtractorNode(Node):
|
||||||
|
|
||||||
_node_data: ParameterExtractorNodeData
|
_node_data: ParameterExtractorNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = ParameterExtractorNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
_model_instance: ModelInstance | None = None
|
_model_instance: ModelInstance | None = None
|
||||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,14 +13,13 @@ from core.prompt.simple_prompt_transform import ModelMode
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
|
||||||
NodeExecutionType,
|
NodeExecutionType,
|
||||||
NodeType,
|
NodeType,
|
||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
)
|
)
|
||||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
from core.workflow.nodes.base.entities import VariableSelector
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
|
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
|
||||||
|
|
@ -44,7 +43,7 @@ if TYPE_CHECKING:
|
||||||
from core.workflow.runtime import GraphRuntimeState
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
|
|
||||||
|
|
||||||
class QuestionClassifierNode(Node):
|
class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||||
node_type = NodeType.QUESTION_CLASSIFIER
|
node_type = NodeType.QUESTION_CLASSIFIER
|
||||||
execution_type = NodeExecutionType.BRANCH
|
execution_type = NodeExecutionType.BRANCH
|
||||||
|
|
||||||
|
|
@ -78,27 +77,6 @@ class QuestionClassifierNode(Node):
|
||||||
)
|
)
|
||||||
self._llm_file_saver = llm_file_saver
|
self._llm_file_saver = llm_file_saver
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = QuestionClassifierNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls):
|
def version(cls):
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,16 @@
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.start.entities import StartNodeData
|
from core.workflow.nodes.start.entities import StartNodeData
|
||||||
|
|
||||||
|
|
||||||
class StartNode(Node):
|
class StartNode(Node[StartNodeData]):
|
||||||
node_type = NodeType.START
|
node_type = NodeType.START
|
||||||
execution_type = NodeExecutionType.ROOT
|
execution_type = NodeExecutionType.ROOT
|
||||||
|
|
||||||
_node_data: StartNodeData
|
_node_data: StartNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = StartNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -3,41 +3,19 @@ from typing import Any
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||||
|
|
||||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||||
|
|
||||||
|
|
||||||
class TemplateTransformNode(Node):
|
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||||
|
|
||||||
_node_data: TemplateTransformNodeData
|
_node_data: TemplateTransformNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = TemplateTransformNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -16,14 +16,12 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||||
from core.variables.variables import ArrayAnyVariable
|
from core.variables.variables import ArrayAnyVariable
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
ErrorStrategy,
|
|
||||||
NodeType,
|
NodeType,
|
||||||
SystemVariableKey,
|
SystemVariableKey,
|
||||||
WorkflowNodeExecutionMetadataKey,
|
WorkflowNodeExecutionMetadataKey,
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
)
|
)
|
||||||
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
@ -42,7 +40,7 @@ if TYPE_CHECKING:
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(Node):
|
class ToolNode(Node[ToolNodeData]):
|
||||||
"""
|
"""
|
||||||
Tool Node
|
Tool Node
|
||||||
"""
|
"""
|
||||||
|
|
@ -51,9 +49,6 @@ class ToolNode(Node):
|
||||||
|
|
||||||
_node_data: ToolNodeData
|
_node_data: ToolNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = ToolNodeData.model_validate(data)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
@ -498,24 +493,6 @@ class ToolNode(Node):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def retry(self) -> bool:
|
def retry(self) -> bool:
|
||||||
return self._node_data.retry_config.retry_enabled
|
return self._node_data.retry_config.retry_enabled
|
||||||
|
|
|
||||||
|
|
@ -1,43 +1,18 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
from core.workflow.enums import NodeExecutionType, NodeType
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
|
|
||||||
from .entities import TriggerEventNodeData
|
from .entities import TriggerEventNodeData
|
||||||
|
|
||||||
|
|
||||||
class TriggerEventNode(Node):
|
class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||||
node_type = NodeType.TRIGGER_PLUGIN
|
node_type = NodeType.TRIGGER_PLUGIN
|
||||||
execution_type = NodeExecutionType.ROOT
|
execution_type = NodeExecutionType.ROOT
|
||||||
|
|
||||||
_node_data: TriggerEventNodeData
|
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
|
||||||
self._node_data = TriggerEventNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -1,42 +1,17 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
from core.workflow.enums import NodeExecutionType, NodeType
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData
|
from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData
|
||||||
|
|
||||||
|
|
||||||
class TriggerScheduleNode(Node):
|
class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
|
||||||
node_type = NodeType.TRIGGER_SCHEDULE
|
node_type = NodeType.TRIGGER_SCHEDULE
|
||||||
execution_type = NodeExecutionType.ROOT
|
execution_type = NodeExecutionType.ROOT
|
||||||
|
|
||||||
_node_data: TriggerScheduleNodeData
|
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
|
||||||
self._node_data = TriggerScheduleNodeData(**data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -3,41 +3,17 @@ from typing import Any
|
||||||
|
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
from core.workflow.enums import NodeExecutionType, NodeType
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
|
|
||||||
from .entities import ContentType, WebhookData
|
from .entities import ContentType, WebhookData
|
||||||
|
|
||||||
|
|
||||||
class TriggerWebhookNode(Node):
|
class TriggerWebhookNode(Node[WebhookData]):
|
||||||
node_type = NodeType.TRIGGER_WEBHOOK
|
node_type = NodeType.TRIGGER_WEBHOOK
|
||||||
execution_type = NodeExecutionType.ROOT
|
execution_type = NodeExecutionType.ROOT
|
||||||
|
|
||||||
_node_data: WebhookData
|
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
|
||||||
self._node_data = WebhookData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -1,40 +1,17 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.variables.segments import Segment
|
from core.variables.segments import Segment
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
|
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
|
||||||
|
|
||||||
|
|
||||||
class VariableAggregatorNode(Node):
|
class VariableAggregatorNode(Node[VariableAssignerNodeData]):
|
||||||
node_type = NodeType.VARIABLE_AGGREGATOR
|
node_type = NodeType.VARIABLE_AGGREGATOR
|
||||||
|
|
||||||
_node_data: VariableAssignerNodeData
|
_node_data: VariableAssignerNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = VariableAssignerNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,8 @@ from core.variables import SegmentType, Variable
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||||
|
|
@ -22,33 +21,12 @@ if TYPE_CHECKING:
|
||||||
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
|
||||||
|
|
||||||
|
|
||||||
class VariableAssignerNode(Node):
|
class VariableAssignerNode(Node[VariableAssignerData]):
|
||||||
node_type = NodeType.VARIABLE_ASSIGNER
|
node_type = NodeType.VARIABLE_ASSIGNER
|
||||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||||
|
|
||||||
_node_data: VariableAssignerData
|
_node_data: VariableAssignerData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = VariableAssignerData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
id: str,
|
id: str,
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,8 @@ from core.variables import SegmentType, Variable
|
||||||
from core.variables.consts import SELECTORS_LENGTH
|
from core.variables.consts import SELECTORS_LENGTH
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import NodeRunResult
|
from core.workflow.node_events import NodeRunResult
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||||
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||||
|
|
@ -51,32 +50,11 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
||||||
mapping[key] = selector
|
mapping[key] = selector
|
||||||
|
|
||||||
|
|
||||||
class VariableAssignerNode(Node):
|
class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||||
node_type = NodeType.VARIABLE_ASSIGNER
|
node_type = NodeType.VARIABLE_ASSIGNER
|
||||||
|
|
||||||
_node_data: VariableAssignerNodeData
|
_node_data: VariableAssignerNodeData
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, Any]):
|
|
||||||
self._node_data = VariableAssignerNodeData.model_validate(data)
|
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
|
||||||
return self._node_data.error_strategy
|
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
|
||||||
return self._node_data.retry_config
|
|
||||||
|
|
||||||
def _get_title(self) -> str:
|
|
||||||
return self._node_data.title
|
|
||||||
|
|
||||||
def _get_description(self) -> str | None:
|
|
||||||
return self._node_data.desc
|
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._node_data
|
|
||||||
|
|
||||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if this Variable Assigner node blocks the output of specific variables.
|
Check if this Variable Assigner node blocks the output of specific variables.
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,6 @@ class WorkflowEntry:
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
node.init_node_data(node_config_data)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# variable selector to variable mapping
|
# variable selector to variable mapping
|
||||||
|
|
@ -303,7 +302,6 @@ class WorkflowEntry:
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
node.init_node_data(node_data)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# variable selector to variable mapping
|
# variable selector to variable mapping
|
||||||
|
|
|
||||||
|
|
@ -69,10 +69,6 @@ def init_code_node(code_config: dict):
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
if "data" in code_config:
|
|
||||||
node.init_node_data(code_config["data"])
|
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,10 +65,6 @@ def init_http_node(config: dict):
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
if "data" in config:
|
|
||||||
node.init_node_data(config["data"])
|
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -709,10 +705,6 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
if "data" in graph_config["nodes"][1]:
|
|
||||||
node.init_node_data(graph_config["nodes"][1]["data"])
|
|
||||||
|
|
||||||
result = node._run()
|
result = node._run()
|
||||||
assert result.process_data is not None
|
assert result.process_data is not None
|
||||||
data = result.process_data.get("request", "")
|
data = result.process_data.get("request", "")
|
||||||
|
|
|
||||||
|
|
@ -82,10 +82,6 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
if "data" in config:
|
|
||||||
node.init_node_data(config["data"])
|
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,6 @@ def init_parameter_extractor_node(config: dict):
|
||||||
graph_init_params=init_params,
|
graph_init_params=init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
node.init_node_data(config.get("data", {}))
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,6 @@ def test_execute_code(setup_code_executor_mock):
|
||||||
graph_init_params=init_params,
|
graph_init_params=init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
node.init_node_data(config.get("data", {}))
|
|
||||||
|
|
||||||
# execute node
|
# execute node
|
||||||
result = node._run()
|
result = node._run()
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,6 @@ def init_tool_node(config: dict):
|
||||||
graph_init_params=init_params,
|
graph_init_params=init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
node.init_node_data(config.get("data", {}))
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -12,14 +11,19 @@ from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph.validation import GraphValidationError
|
from core.workflow.graph.validation import GraphValidationError
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
|
|
||||||
|
|
||||||
class _TestNode(Node):
|
class _TestNodeData(BaseNodeData):
|
||||||
|
type: NodeType | str | None = None
|
||||||
|
execution_type: NodeExecutionType | str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _TestNode(Node[_TestNodeData]):
|
||||||
node_type = NodeType.ANSWER
|
node_type = NodeType.ANSWER
|
||||||
execution_type = NodeExecutionType.EXECUTABLE
|
execution_type = NodeExecutionType.EXECUTABLE
|
||||||
|
|
||||||
|
|
@ -41,31 +45,8 @@ class _TestNode(Node):
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
data = config.get("data", {})
|
|
||||||
if isinstance(data, Mapping):
|
|
||||||
execution_type = data.get("execution_type")
|
|
||||||
if isinstance(execution_type, str):
|
|
||||||
self.execution_type = NodeExecutionType(execution_type)
|
|
||||||
self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
|
|
||||||
self.data: dict[str, object] = {}
|
|
||||||
|
|
||||||
def init_node_data(self, data: Mapping[str, object]) -> None:
|
node_type_value = self.data.get("type")
|
||||||
title = str(data.get("title", self.id))
|
|
||||||
desc = data.get("description")
|
|
||||||
error_strategy_value = data.get("error_strategy")
|
|
||||||
error_strategy: ErrorStrategy | None = None
|
|
||||||
if isinstance(error_strategy_value, ErrorStrategy):
|
|
||||||
error_strategy = error_strategy_value
|
|
||||||
elif isinstance(error_strategy_value, str):
|
|
||||||
error_strategy = ErrorStrategy(error_strategy_value)
|
|
||||||
self._base_node_data = BaseNodeData(
|
|
||||||
title=title,
|
|
||||||
desc=str(desc) if desc is not None else None,
|
|
||||||
error_strategy=error_strategy,
|
|
||||||
)
|
|
||||||
self.data = dict(data)
|
|
||||||
|
|
||||||
node_type_value = data.get("type")
|
|
||||||
if isinstance(node_type_value, NodeType):
|
if isinstance(node_type_value, NodeType):
|
||||||
self.node_type = node_type_value
|
self.node_type = node_type_value
|
||||||
elif isinstance(node_type_value, str):
|
elif isinstance(node_type_value, str):
|
||||||
|
|
@ -77,23 +58,19 @@ class _TestNode(Node):
|
||||||
def _run(self):
|
def _run(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
def post_init(self) -> None:
|
||||||
return self._base_node_data.error_strategy
|
super().post_init()
|
||||||
|
self._maybe_override_execution_type()
|
||||||
|
self.data = dict(self.node_data.model_dump())
|
||||||
|
|
||||||
def _get_retry_config(self) -> RetryConfig:
|
def _maybe_override_execution_type(self) -> None:
|
||||||
return self._base_node_data.retry_config
|
execution_type_value = self.node_data.execution_type
|
||||||
|
if execution_type_value is None:
|
||||||
def _get_title(self) -> str:
|
return
|
||||||
return self._base_node_data.title
|
if isinstance(execution_type_value, NodeExecutionType):
|
||||||
|
self.execution_type = execution_type_value
|
||||||
def _get_description(self) -> str | None:
|
else:
|
||||||
return self._base_node_data.desc
|
self.execution_type = NodeExecutionType(execution_type_value)
|
||||||
|
|
||||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
|
||||||
return self._base_node_data.default_value_dict
|
|
||||||
|
|
||||||
def get_base_node_data(self) -> BaseNodeData:
|
|
||||||
return self._base_node_data
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
|
|
@ -109,7 +86,6 @@ class _SimpleNodeFactory:
|
||||||
graph_init_params=self.graph_init_params,
|
graph_init_params=self.graph_init_params,
|
||||||
graph_runtime_state=self.graph_runtime_state,
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
)
|
)
|
||||||
node.init_node_data(node_config.get("data", {}))
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ def test_abort_command():
|
||||||
# Create mock nodes with required attributes - using shared runtime state
|
# Create mock nodes with required attributes - using shared runtime state
|
||||||
start_node = StartNode(
|
start_node = StartNode(
|
||||||
id="start",
|
id="start",
|
||||||
config={"id": "start"},
|
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||||
graph_init_params=GraphInitParams(
|
graph_init_params=GraphInitParams(
|
||||||
tenant_id="test_tenant",
|
tenant_id="test_tenant",
|
||||||
app_id="test_app",
|
app_id="test_app",
|
||||||
|
|
@ -45,7 +45,6 @@ def test_abort_command():
|
||||||
),
|
),
|
||||||
graph_runtime_state=shared_runtime_state,
|
graph_runtime_state=shared_runtime_state,
|
||||||
)
|
)
|
||||||
start_node.init_node_data({"title": "start", "variables": []})
|
|
||||||
mock_graph.nodes["start"] = start_node
|
mock_graph.nodes["start"] = start_node
|
||||||
|
|
||||||
# Mock graph methods
|
# Mock graph methods
|
||||||
|
|
@ -142,7 +141,7 @@ def test_pause_command():
|
||||||
|
|
||||||
start_node = StartNode(
|
start_node = StartNode(
|
||||||
id="start",
|
id="start",
|
||||||
config={"id": "start"},
|
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||||
graph_init_params=GraphInitParams(
|
graph_init_params=GraphInitParams(
|
||||||
tenant_id="test_tenant",
|
tenant_id="test_tenant",
|
||||||
app_id="test_app",
|
app_id="test_app",
|
||||||
|
|
@ -155,7 +154,6 @@ def test_pause_command():
|
||||||
),
|
),
|
||||||
graph_runtime_state=shared_runtime_state,
|
graph_runtime_state=shared_runtime_state,
|
||||||
)
|
)
|
||||||
start_node.init_node_data({"title": "start", "variables": []})
|
|
||||||
mock_graph.nodes["start"] = start_node
|
mock_graph.nodes["start"] = start_node
|
||||||
|
|
||||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
start_node.init_node_data(start_config["data"])
|
|
||||||
|
|
||||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||||
llm_data = LLMNodeData(
|
llm_data = LLMNodeData(
|
||||||
|
|
@ -88,7 +87,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
llm_node.init_node_data(llm_config["data"])
|
|
||||||
return llm_node
|
return llm_node
|
||||||
|
|
||||||
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
|
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
|
||||||
|
|
@ -105,7 +103,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
human_node.init_node_data(human_config["data"])
|
|
||||||
|
|
||||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||||
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
|
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
|
||||||
|
|
@ -125,7 +122,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
end_primary.init_node_data(end_primary_config["data"])
|
|
||||||
|
|
||||||
end_secondary_data = EndNodeData(
|
end_secondary_data = EndNodeData(
|
||||||
title="End Secondary",
|
title="End Secondary",
|
||||||
|
|
@ -142,7 +138,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
end_secondary.init_node_data(end_secondary_config["data"])
|
|
||||||
|
|
||||||
graph = (
|
graph = (
|
||||||
Graph.new()
|
Graph.new()
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
start_node.init_node_data(start_config["data"])
|
|
||||||
|
|
||||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||||
llm_data = LLMNodeData(
|
llm_data = LLMNodeData(
|
||||||
|
|
@ -87,7 +86,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
llm_node.init_node_data(llm_config["data"])
|
|
||||||
return llm_node
|
return llm_node
|
||||||
|
|
||||||
llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
|
llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
|
||||||
|
|
@ -104,7 +102,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
human_node.init_node_data(human_config["data"])
|
|
||||||
|
|
||||||
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
|
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
|
||||||
|
|
||||||
|
|
@ -123,7 +120,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
end_node.init_node_data(end_config["data"])
|
|
||||||
|
|
||||||
graph = (
|
graph = (
|
||||||
Graph.new()
|
Graph.new()
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
start_node.init_node_data(start_config["data"])
|
|
||||||
|
|
||||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||||
llm_data = LLMNodeData(
|
llm_data = LLMNodeData(
|
||||||
|
|
@ -87,7 +86,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
llm_node.init_node_data(llm_config["data"])
|
|
||||||
return llm_node
|
return llm_node
|
||||||
|
|
||||||
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
|
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
|
||||||
|
|
@ -118,7 +116,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
if_else_node.init_node_data(if_else_config["data"])
|
|
||||||
|
|
||||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||||
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
|
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
|
||||||
|
|
@ -138,7 +135,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
end_primary.init_node_data(end_primary_config["data"])
|
|
||||||
|
|
||||||
end_secondary_data = EndNodeData(
|
end_secondary_data = EndNodeData(
|
||||||
title="End Secondary",
|
title="End Secondary",
|
||||||
|
|
@ -155,7 +151,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
end_secondary.init_node_data(end_secondary_config["data"])
|
|
||||||
|
|
||||||
graph = (
|
graph = (
|
||||||
Graph.new()
|
Graph.new()
|
||||||
|
|
|
||||||
|
|
@ -111,9 +111,6 @@ class MockNodeFactory(DifyNodeFactory):
|
||||||
mock_config=self.mock_config,
|
mock_config=self.mock_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node with provided data
|
|
||||||
mock_instance.init_node_data(node_data)
|
|
||||||
|
|
||||||
return mock_instance
|
return mock_instance
|
||||||
|
|
||||||
# For non-mocked node types, use parent implementation
|
# For non-mocked node types, use parent implementation
|
||||||
|
|
|
||||||
|
|
@ -142,6 +142,8 @@ def test_mock_loop_node_preserves_config():
|
||||||
"start_node_id": "node1",
|
"start_node_id": "node1",
|
||||||
"loop_variables": [],
|
"loop_variables": [],
|
||||||
"outputs": {},
|
"outputs": {},
|
||||||
|
"break_conditions": [],
|
||||||
|
"logical_operator": "and",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,6 @@ class TestMockTemplateTransformNode:
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
mock_node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Run the node
|
# Run the node
|
||||||
result = mock_node._run()
|
result = mock_node._run()
|
||||||
|
|
@ -125,7 +124,6 @@ class TestMockTemplateTransformNode:
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
mock_node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Run the node
|
# Run the node
|
||||||
result = mock_node._run()
|
result = mock_node._run()
|
||||||
|
|
@ -184,7 +182,6 @@ class TestMockTemplateTransformNode:
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
mock_node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Run the node
|
# Run the node
|
||||||
result = mock_node._run()
|
result = mock_node._run()
|
||||||
|
|
@ -246,7 +243,6 @@ class TestMockTemplateTransformNode:
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
mock_node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Run the node
|
# Run the node
|
||||||
result = mock_node._run()
|
result = mock_node._run()
|
||||||
|
|
@ -311,7 +307,6 @@ class TestMockCodeNode:
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
mock_node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Run the node
|
# Run the node
|
||||||
result = mock_node._run()
|
result = mock_node._run()
|
||||||
|
|
@ -376,7 +371,6 @@ class TestMockCodeNode:
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
mock_node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Run the node
|
# Run the node
|
||||||
result = mock_node._run()
|
result = mock_node._run()
|
||||||
|
|
@ -445,7 +439,6 @@ class TestMockCodeNode:
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
mock_config=mock_config,
|
mock_config=mock_config,
|
||||||
)
|
)
|
||||||
mock_node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Run the node
|
# Run the node
|
||||||
result = mock_node._run()
|
result = mock_node._run()
|
||||||
|
|
|
||||||
|
|
@ -83,9 +83,6 @@ def test_execute_answer():
|
||||||
config=node_config,
|
config=node_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Mock db.session.close()
|
# Mock db.session.close()
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,7 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
|
|
||||||
# Ensures that all node classes are imported.
|
# Ensures that all node classes are imported.
|
||||||
|
|
@ -7,6 +10,12 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
_ = NODE_TYPE_CLASSES_MAPPING
|
_ = NODE_TYPE_CLASSES_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
class _TestNodeData(BaseNodeData):
|
||||||
|
"""Test node data for unit tests."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_all_subclasses(root: type[Node]) -> list[type[Node]]:
|
def _get_all_subclasses(root: type[Node]) -> list[type[Node]]:
|
||||||
subclasses = []
|
subclasses = []
|
||||||
queue = [root]
|
queue = [root]
|
||||||
|
|
@ -34,3 +43,79 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
||||||
node_type_and_version = (node_type, node_version)
|
node_type_and_version = (node_type, node_version)
|
||||||
assert node_type_and_version not in type_version_set
|
assert node_type_and_version not in type_version_set
|
||||||
type_version_set.add(node_type_and_version)
|
type_version_set.add(node_type_and_version)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_node_data_type_from_generic_extracts_type():
|
||||||
|
"""When a class inherits from Node[T], it should extract T."""
|
||||||
|
|
||||||
|
class _ConcreteNode(Node[_TestNodeData]):
|
||||||
|
node_type = NodeType.CODE
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def version() -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
result = _ConcreteNode._extract_node_data_type_from_generic()
|
||||||
|
|
||||||
|
assert result is _TestNodeData
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_node_data_type_from_generic_returns_none_for_base_node():
|
||||||
|
"""The base Node class itself should return None (no generic parameter)."""
|
||||||
|
result = Node._extract_node_data_type_from_generic()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_node_data_type_from_generic_raises_for_non_base_node_data():
|
||||||
|
"""When generic parameter is not a BaseNodeData subtype, should raise TypeError."""
|
||||||
|
with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"):
|
||||||
|
|
||||||
|
class _InvalidNode(Node[str]): # type: ignore[type-arg]
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_node_data_type_from_generic_raises_for_non_type():
|
||||||
|
"""When generic parameter is not a concrete type, should raise TypeError."""
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"):
|
||||||
|
|
||||||
|
class _InvalidNode(Node[T]): # type: ignore[type-arg]
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_subclass_raises_without_generic_or_explicit_type():
|
||||||
|
"""A subclass must either use Node[T] or explicitly set _node_data_type."""
|
||||||
|
with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):
|
||||||
|
|
||||||
|
class _InvalidNode(Node):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_subclass_rejects_explicit_node_data_type_without_generic():
|
||||||
|
"""Setting _node_data_type explicitly cannot bypass the Node[T] requirement."""
|
||||||
|
with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):
|
||||||
|
|
||||||
|
class _ExplicitNode(Node):
|
||||||
|
_node_data_type = _TestNodeData
|
||||||
|
node_type = NodeType.CODE
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def version() -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_subclass_sets_node_data_type_from_generic():
|
||||||
|
"""Verify that __init_subclass__ sets _node_data_type from the generic parameter."""
|
||||||
|
|
||||||
|
class _AutoNode(Node[_TestNodeData]):
|
||||||
|
node_type = NodeType.CODE
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def version() -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
assert _AutoNode._node_data_type is _TestNodeData
|
||||||
|
|
|
||||||
|
|
@ -111,8 +111,6 @@ def llm_node(
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
llm_file_saver=mock_file_saver,
|
llm_file_saver=mock_file_saver,
|
||||||
)
|
)
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -498,8 +496,6 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
llm_file_saver=mock_file_saver,
|
llm_file_saver=mock_file_saver,
|
||||||
)
|
)
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
return node, mock_file_saver
|
return node, mock_file_saver
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.workflow.entities import GraphInitParams
|
||||||
|
from core.workflow.enums import NodeType
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
|
from core.workflow.nodes.base.node import Node
|
||||||
|
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||||
|
from core.workflow.system_variable import SystemVariable
|
||||||
|
|
||||||
|
|
||||||
|
class _SampleNodeData(BaseNodeData):
|
||||||
|
foo: str
|
||||||
|
|
||||||
|
|
||||||
|
class _SampleNode(Node[_SampleNodeData]):
|
||||||
|
node_type = NodeType.ANSWER
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "sample-test"
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
|
||||||
|
init_params = GraphInitParams(
|
||||||
|
tenant_id="tenant",
|
||||||
|
app_id="app",
|
||||||
|
workflow_id="workflow",
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id="user",
|
||||||
|
user_from="account",
|
||||||
|
invoke_from="debugger",
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
runtime_state = GraphRuntimeState(
|
||||||
|
variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}),
|
||||||
|
start_at=0.0,
|
||||||
|
)
|
||||||
|
return init_params, runtime_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_hydrates_data_during_initialization():
|
||||||
|
graph_config: dict[str, object] = {}
|
||||||
|
init_params, runtime_state = _build_context(graph_config)
|
||||||
|
|
||||||
|
node = _SampleNode(
|
||||||
|
id="node-1",
|
||||||
|
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
|
||||||
|
graph_init_params=init_params,
|
||||||
|
graph_runtime_state=runtime_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert node.node_data.foo == "bar"
|
||||||
|
assert node.title == "Sample"
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_generic_argument_raises_type_error():
|
||||||
|
graph_config: dict[str, object] = {}
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
|
||||||
|
class _InvalidNode(Node): # type: ignore[type-abstract]
|
||||||
|
node_type = NodeType.ANSWER
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
@ -50,8 +50,6 @@ def document_extractor_node(graph_init_params):
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=Mock(),
|
graph_runtime_state=Mock(),
|
||||||
)
|
)
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -114,9 +114,6 @@ def test_execute_if_else_result_true():
|
||||||
config=node_config,
|
config=node_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Mock db.session.close()
|
# Mock db.session.close()
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
||||||
|
|
@ -187,9 +184,6 @@ def test_execute_if_else_result_false():
|
||||||
config=node_config,
|
config=node_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Mock db.session.close()
|
# Mock db.session.close()
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
||||||
|
|
@ -252,9 +246,6 @@ def test_array_file_contains_file_name():
|
||||||
config=node_config,
|
config=node_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
|
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
|
||||||
value=[
|
value=[
|
||||||
File(
|
File(
|
||||||
|
|
@ -347,7 +338,6 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
config={"id": "if-else", "data": node_data},
|
config={"id": "if-else", "data": node_data},
|
||||||
)
|
)
|
||||||
node.init_node_data(node_data)
|
|
||||||
|
|
||||||
# Mock db.session.close()
|
# Mock db.session.close()
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
@ -417,7 +407,6 @@ def test_execute_if_else_boolean_false_conditions():
|
||||||
"data": node_data,
|
"data": node_data,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
node.init_node_data(node_data)
|
|
||||||
|
|
||||||
# Mock db.session.close()
|
# Mock db.session.close()
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
@ -487,7 +476,6 @@ def test_execute_if_else_boolean_cases_structure():
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
config={"id": "if-else", "data": node_data},
|
config={"id": "if-else", "data": node_data},
|
||||||
)
|
)
|
||||||
node.init_node_data(node_data)
|
|
||||||
|
|
||||||
# Mock db.session.close()
|
# Mock db.session.close()
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
|
||||||
|
|
@ -57,8 +57,6 @@ def list_operator_node():
|
||||||
graph_init_params=graph_init_params,
|
graph_init_params=graph_init_params,
|
||||||
graph_runtime_state=MagicMock(),
|
graph_runtime_state=MagicMock(),
|
||||||
)
|
)
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
node.graph_runtime_state = MagicMock()
|
node.graph_runtime_state = MagicMock()
|
||||||
node.graph_runtime_state.variable_pool = MagicMock()
|
node.graph_runtime_state.variable_pool = MagicMock()
|
||||||
return node
|
return node
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,6 @@ def tool_node(monkeypatch) -> "ToolNode":
|
||||||
graph_init_params=init_params,
|
graph_init_params=init_params,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
)
|
)
|
||||||
node.init_node_data(config["data"])
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -101,9 +101,6 @@ def test_overwrite_string_variable():
|
||||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
list(node.run())
|
list(node.run())
|
||||||
expected_var = StringVariable(
|
expected_var = StringVariable(
|
||||||
id=conversation_variable.id,
|
id=conversation_variable.id,
|
||||||
|
|
@ -203,9 +200,6 @@ def test_append_variable_to_array():
|
||||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
list(node.run())
|
list(node.run())
|
||||||
expected_value = list(conversation_variable.value)
|
expected_value = list(conversation_variable.value)
|
||||||
expected_value.append(input_variable.value)
|
expected_value.append(input_variable.value)
|
||||||
|
|
@ -296,9 +290,6 @@ def test_clear_array():
|
||||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
list(node.run())
|
list(node.run())
|
||||||
expected_var = ArrayStringVariable(
|
expected_var = ArrayStringVariable(
|
||||||
id=conversation_variable.id,
|
id=conversation_variable.id,
|
||||||
|
|
|
||||||
|
|
@ -139,11 +139,6 @@ def test_remove_first_from_array():
|
||||||
config=node_config,
|
config=node_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Skip the mock assertion since we're in a test environment
|
|
||||||
|
|
||||||
# Run the node
|
# Run the node
|
||||||
result = list(node.run())
|
result = list(node.run())
|
||||||
|
|
||||||
|
|
@ -228,10 +223,6 @@ def test_remove_last_from_array():
|
||||||
config=node_config,
|
config=node_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Skip the mock assertion since we're in a test environment
|
|
||||||
list(node.run())
|
list(node.run())
|
||||||
|
|
||||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||||
|
|
@ -313,10 +304,6 @@ def test_remove_first_from_empty_array():
|
||||||
config=node_config,
|
config=node_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Skip the mock assertion since we're in a test environment
|
|
||||||
list(node.run())
|
list(node.run())
|
||||||
|
|
||||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||||
|
|
@ -398,10 +385,6 @@ def test_remove_last_from_empty_array():
|
||||||
config=node_config,
|
config=node_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize node data
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
|
|
||||||
# Skip the mock assertion since we're in a test environment
|
|
||||||
list(node.run())
|
list(node.run())
|
||||||
|
|
||||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
node.init_node_data(node_config["data"])
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue