From 989db0e584e039185e3a6edc3579857042695734 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Wed, 11 Mar 2026 23:43:58 +0800 Subject: [PATCH] refactor: Unify NodeConfigDict.data and BaseNodeData (#32780) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/core/app/apps/workflow_app_runner.py | 8 +- api/core/trigger/debug/event_selectors.py | 7 +- api/core/workflow/node_factory.py | 291 +++++++----------- api/core/workflow/workflow_entry.py | 12 +- api/dify_graph/entities/base_node_data.py | 176 +++++++++++ .../{nodes/base => entities}/exc.py | 0 api/dify_graph/entities/graph_config.py | 11 +- api/dify_graph/graph/graph.py | 31 +- api/dify_graph/nodes/agent/agent_node.py | 7 +- api/dify_graph/nodes/agent/entities.py | 4 +- api/dify_graph/nodes/answer/answer_node.py | 8 +- api/dify_graph/nodes/answer/entities.py | 4 +- api/dify_graph/nodes/base/__init__.py | 3 +- api/dify_graph/nodes/base/entities.py | 131 +------- api/dify_graph/nodes/base/node.py | 51 ++- api/dify_graph/nodes/code/code_node.py | 10 +- api/dify_graph/nodes/code/entities.py | 5 +- .../nodes/datasource/datasource_node.py | 12 +- api/dify_graph/nodes/datasource/entities.py | 5 +- .../nodes/document_extractor/entities.py | 4 +- .../nodes/document_extractor/node.py | 11 +- api/dify_graph/nodes/end/entities.py | 5 +- api/dify_graph/nodes/http_request/entities.py | 4 +- api/dify_graph/nodes/http_request/node.py | 20 +- api/dify_graph/nodes/human_input/entities.py | 4 +- .../nodes/human_input/human_input_node.py | 8 +- api/dify_graph/nodes/if_else/entities.py | 5 +- api/dify_graph/nodes/if_else/if_else_node.py | 8 +- api/dify_graph/nodes/iteration/entities.py | 7 +- .../nodes/iteration/iteration_node.py | 19 +- .../nodes/knowledge_index/entities.py | 5 +- .../knowledge_index/knowledge_index_node.py | 3 +- .../nodes/knowledge_retrieval/entities.py | 5 +- .../knowledge_retrieval_node.py | 16 +- .../nodes/list_operator/entities.py | 4 +- api/dify_graph/nodes/llm/entities.py | 4 +- api/dify_graph/nodes/llm/node.py | 26 +- api/dify_graph/nodes/loop/entities.py | 9 +- api/dify_graph/nodes/loop/loop_node.py | 15 +- .../nodes/parameter_extractor/entities.py | 4 +- .../parameter_extractor_node.py | 15 +- .../nodes/question_classifier/entities.py | 4 +- .../question_classifier_node.py | 14 +- api/dify_graph/nodes/start/entities.py | 4 +- .../nodes/template_transform/entities.py | 4 +- .../template_transform_node.py | 10 +- api/dify_graph/nodes/tool/entities.py | 5 +- api/dify_graph/nodes/tool/tool_node.py | 10 +- .../nodes/trigger_plugin/entities.py | 7 +- .../nodes/trigger_schedule/entities.py | 4 +- api/dify_graph/nodes/trigger_schedule/exc.py | 2 +- .../nodes/trigger_webhook/entities.py | 81 ++++- api/dify_graph/nodes/trigger_webhook/exc.py | 2 +- api/dify_graph/nodes/trigger_webhook/node.py | 2 +- .../nodes/variable_aggregator/entities.py | 4 +- .../nodes/variable_assigner/v1/node.py | 18 +- .../nodes/variable_assigner/v1/node_data.py | 4 +- .../nodes/variable_assigner/v2/entities.py | 4 +- .../nodes/variable_assigner/v2/node.py | 10 +- api/models/workflow.py | 14 +- api/services/trigger/schedule_service.py | 40 ++- api/services/trigger/trigger_service.py | 5 +- api/services/trigger/webhook_service.py | 169 +++++----- api/services/workflow_service.py | 13 +- .../workflow/nodes/test_http.py | 2 + .../services/test_webhook_service.py | 2 +- .../unit_tests/configs/test_dify_config.py | 9 +- .../core/app/apps/test_pause_resume.py | 18 +- .../test_workflow_app_runner_single_node.py | 56 ++++ .../workflow/graph/test_graph_validation.py | 35 ++- .../event_management/test_event_handlers.py | 2 +- .../graph_engine/test_graph_engine.py | 2 +- .../graph_engine/test_mock_factory.py | 34 +- .../workflow/nodes/base/test_base_node.py | 19 +- .../test_get_node_type_classes_mapping.py | 2 +- .../workflow/nodes/code/code_node_spec.py | 12 +- .../nodes/iteration/iteration_node_spec.py | 48 +++ .../test_knowledge_retrieval_node.py | 16 +- .../core/workflow/nodes/test_base_node.py | 34 +- .../core/workflow/nodes/test_loop_node.py | 52 ++++ .../workflow/nodes/webhook/test_entities.py | 35 ++- .../workflow/nodes/webhook/test_exceptions.py | 2 +- .../core/workflow/test_node_factory.py | 82 +++++ .../core/workflow/test_workflow_entry.py | 3 +- .../test_workflow_human_input_delivery.py | 5 +- .../workflow/test_workflow_service.py | 12 +- 86 files changed, 1172 insertions(+), 717 deletions(-) create mode 100644 api/dify_graph/entities/base_node_data.py rename api/dify_graph/{nodes/base => entities}/exc.py (100%) create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_loop_node.py create mode 100644 api/tests/unit_tests/core/workflow/test_node_factory.py diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 7ef6ff7cc2..3812fac28a 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -32,6 +32,7 @@ from core.app.entities.queue_entities import ( from core.workflow.node_factory import DifyNodeFactory from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.graph import Graph from dify_graph.graph_engine.layers.base import GraphEngineLayer @@ -62,7 +63,6 @@ from dify_graph.graph_events import ( NodeRunSucceededEvent, ) from dify_graph.graph_events.graph import GraphRunAbortedEvent -from dify_graph.nodes import NodeType from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -303,9 +303,11 @@ class WorkflowBasedAppRunner: if not target_node_config: raise ValueError(f"{node_type_label} node id not found in workflow graph") + target_node_config = NodeConfigDictAdapter.validate_python(target_node_config) + # Get node class - node_type = NodeType(target_node_config.get("data", {}).get("type")) - node_version = target_node_config.get("data", {}).get("version", "1") + node_type = target_node_config["data"].type + node_version = str(target_node_config["data"].version) node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] # Use the variable pool from graph_runtime_state instead of creating a new one diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 9b7b3de614..442a2434d5 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -19,6 +19,7 @@ from core.trigger.debug.events import ( build_plugin_pool_key, build_webhook_pool_key, ) +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig @@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC): app_id: str user_id: str tenant_id: str - node_config: Mapping[str, Any] + node_config: NodeConfigDict node_id: str - def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str): + def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str): self.tenant_id = tenant_id self.user_id = user_id self.app_id = app_id @@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller): def poll(self) -> TriggerDebugEvent | None: from services.trigger.trigger_service import TriggerService - plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {})) + plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True) provider_id = TriggerProviderID(plugin_trigger_data.provider_id) pool_key: str = build_plugin_pool_key( name=plugin_trigger_data.event_name, diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 8c6b1dedee..d7f2a67c06 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -1,5 +1,5 @@ -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, cast, final +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, TypeAlias, cast, final from sqlalchemy import select from sqlalchemy.orm import Session @@ -22,7 +22,8 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.summary_index.summary_index import SummaryIndex from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.tools.tool_file_manager import ToolFileManager -from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import NodeType, SystemVariableKey from dify_graph.file.file_manager import file_manager @@ -31,26 +32,19 @@ from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.memory import PromptMessageMemory from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.code_node import CodeNode, WorkflowCodeExecutor +from dify_graph.nodes.code.code_node import WorkflowCodeExecutor from dify_graph.nodes.code.entities import CodeLanguage from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.nodes.datasource import DatasourceNode -from dify_graph.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig -from dify_graph.nodes.http_request import HttpRequestNode, build_http_request_config -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode -from dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from dify_graph.nodes.llm.entities import ModelConfig +from dify_graph.nodes.document_extractor import UnstructuredApiConfig +from dify_graph.nodes.http_request import build_http_request_config +from dify_graph.nodes.llm.entities import LLMNodeData from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.node import LLMNode from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData from dify_graph.nodes.template_transform.template_renderer import ( CodeExecutorJinja2TemplateRenderer, ) -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.nodes.tool.tool_node import ToolNode from dify_graph.variables.segments import StringSegment from extensions.ext_database import db from models.model import Conversation @@ -60,6 +54,9 @@ if TYPE_CHECKING: from dify_graph.runtime import GraphRuntimeState +LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData + + def fetch_memory( *, conversation_id: str | None, @@ -157,178 +154,128 @@ class DifyNodeFactory(NodeFactory): return DifyRunContext.model_validate(raw_ctx) @override - def create_node(self, node_config: NodeConfigDict) -> Node: + def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ Create a Node instance from node configuration data using the traditional mapping. :param node_config: node configuration dictionary containing type and other data :return: initialized Node instance - :raises ValueError: if node type is unknown or configuration is invalid + :raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation + (including pydantic ValidationError, which subclasses ValueError), + if node type is unknown, or if no implementation exists for the resolved version """ - # Get node_id from config - node_id = node_config["id"] + typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + node_id = typed_node_config["id"] + node_data = typed_node_config["data"] + node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) + node_type = node_data.type + node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { + NodeType.CODE: lambda: { + "code_executor": self._code_executor, + "code_limits": self._code_limits, + }, + NodeType.TEMPLATE_TRANSFORM: lambda: { + "template_renderer": self._template_renderer, + "max_output_length": self._template_transform_max_output_length, + }, + NodeType.HTTP_REQUEST: lambda: { + "http_request_config": self._http_request_config, + "http_client": self._http_request_http_client, + "tool_file_manager_factory": self._http_request_tool_file_manager_factory, + "file_manager": self._http_request_file_manager, + }, + NodeType.HUMAN_INPUT: lambda: { + "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + }, + NodeType.KNOWLEDGE_INDEX: lambda: { + "index_processor": IndexProcessor(), + "summary_index_service": SummaryIndex(), + }, + NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs( + node_class=node_class, + node_data=node_data, + include_http_client=True, + ), + NodeType.DATASOURCE: lambda: { + "datasource_manager": DatasourceManager, + }, + NodeType.KNOWLEDGE_RETRIEVAL: lambda: { + "rag_retrieval": self._rag_retrieval, + }, + NodeType.DOCUMENT_EXTRACTOR: lambda: { + "unstructured_api_config": self._document_extractor_unstructured_api_config, + "http_client": self._http_request_http_client, + }, + NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( + node_class=node_class, + node_data=node_data, + include_http_client=True, + ), + NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( + node_class=node_class, + node_data=node_data, + include_http_client=False, + ), + NodeType.TOOL: lambda: { + "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), + }, + } + node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() + return node_class( + id=node_id, + config=typed_node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + **node_init_kwargs, + ) - # Get node type from config - node_data = node_config["data"] - try: - node_type = NodeType(node_data["type"]) - except ValueError: - raise ValueError(f"Unknown node type: {node_data['type']}") + @staticmethod + def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData: + """ + Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class. + """ + return node_class.validate_node_data(node_data) - # Get node class + @staticmethod + def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type) if not node_mapping: raise ValueError(f"No class mapping found for node type: {node_type}") latest_node_class = node_mapping.get(LATEST_VERSION) - node_version = str(node_data.get("version", "1")) matched_node_class = node_mapping.get(node_version) node_class = matched_node_class or latest_node_class if not node_class: raise ValueError(f"No latest version class found for node type: {node_type}") + return node_class - # Create node instance - if node_type == NodeType.CODE: - return CodeNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - code_executor=self._code_executor, - code_limits=self._code_limits, - ) - - if node_type == NodeType.TEMPLATE_TRANSFORM: - return TemplateTransformNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - template_renderer=self._template_renderer, - max_output_length=self._template_transform_max_output_length, - ) - - if node_type == NodeType.HTTP_REQUEST: - return HttpRequestNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - http_request_config=self._http_request_config, - http_client=self._http_request_http_client, - tool_file_manager_factory=self._http_request_tool_file_manager_factory, - file_manager=self._http_request_file_manager, - ) - - if node_type == NodeType.HUMAN_INPUT: - return HumanInputNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), - ) - - if node_type == NodeType.KNOWLEDGE_INDEX: - return KnowledgeIndexNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - index_processor=IndexProcessor(), - summary_index_service=SummaryIndex(), - ) - - if node_type == NodeType.LLM: - model_instance = self._build_model_instance_for_llm_node(node_data) - memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance) - return LLMNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - credentials_provider=self._llm_credentials_provider, - model_factory=self._llm_model_factory, - model_instance=model_instance, - memory=memory, - http_client=self._http_request_http_client, - ) - - if node_type == NodeType.DATASOURCE: - return DatasourceNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - datasource_manager=DatasourceManager, - ) - - if node_type == NodeType.KNOWLEDGE_RETRIEVAL: - return KnowledgeRetrievalNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - rag_retrieval=self._rag_retrieval, - ) - - if node_type == NodeType.DOCUMENT_EXTRACTOR: - return DocumentExtractorNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - unstructured_api_config=self._document_extractor_unstructured_api_config, - http_client=self._http_request_http_client, - ) - - if node_type == NodeType.QUESTION_CLASSIFIER: - model_instance = self._build_model_instance_for_llm_node(node_data) - memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance) - return QuestionClassifierNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - credentials_provider=self._llm_credentials_provider, - model_factory=self._llm_model_factory, - model_instance=model_instance, - memory=memory, - http_client=self._http_request_http_client, - ) - - if node_type == NodeType.PARAMETER_EXTRACTOR: - model_instance = self._build_model_instance_for_llm_node(node_data) - memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance) - return ParameterExtractorNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - credentials_provider=self._llm_credentials_provider, - model_factory=self._llm_model_factory, - model_instance=model_instance, - memory=memory, - ) - - if node_type == NodeType.TOOL: - return ToolNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - tool_file_manager_factory=self._http_request_tool_file_manager_factory(), - ) - - return node_class( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, + def _build_llm_compatible_node_init_kwargs( + self, + *, + node_class: type[Node], + node_data: BaseNodeData, + include_http_client: bool, + ) -> dict[str, object]: + validated_node_data = cast( + LLMCompatibleNodeData, + self._validate_resolved_node_data(node_class=node_class, node_data=node_data), ) + model_instance = self._build_model_instance_for_llm_node(validated_node_data) + node_init_kwargs: dict[str, object] = { + "credentials_provider": self._llm_credentials_provider, + "model_factory": self._llm_model_factory, + "model_instance": model_instance, + "memory": self._build_memory_for_llm_node( + node_data=validated_node_data, + model_instance=model_instance, + ), + } + if include_http_client: + node_init_kwargs["http_client"] = self._http_request_http_client + return node_init_kwargs - def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance: - node_data_model = ModelConfig.model_validate(node_data["model"]) + def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: + node_data_model = node_data.model if not node_data_model.mode: raise LLMModeRequiredError("LLM mode is required.") @@ -364,14 +311,12 @@ class DifyNodeFactory(NodeFactory): def _build_memory_for_llm_node( self, *, - node_data: Mapping[str, Any], + node_data: LLMCompatibleNodeData, model_instance: ModelInstance, ) -> PromptMessageMemory | None: - raw_memory_config = node_data.get("memory") - if raw_memory_config is None: + if node_data.memory is None: return None - node_memory = MemoryConfig.model_validate(raw_memory_config) conversation_id_variable = self.graph_runtime_state.variable_pool.get( ["sys", SystemVariableKey.CONVERSATION_ID] ) @@ -381,6 +326,6 @@ class DifyNodeFactory(NodeFactory): return fetch_memory( conversation_id=conversation_id, app_id=self._dify_context.app_id, - node_data_memory=node_memory, + node_data_memory=node_data.memory, model_instance=model_instance, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index c259e7ac08..fef01049c5 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -11,7 +11,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer from core.workflow.node_factory import DifyNodeFactory from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigData, NodeConfigDict +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.file.models import File from dify_graph.graph import Graph @@ -212,7 +212,7 @@ class WorkflowEntry: node_config_data = node_config["data"] # Get node type - node_type = NodeType(node_config_data["type"]) + node_type = node_config_data.type # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -234,8 +234,7 @@ class WorkflowEntry: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - typed_node_config = cast(dict[str, object], node_config) - node = cast(Any, node_factory).create_node(typed_node_config) + node = node_factory.create_node(node_config) node_cls = type(node) try: @@ -371,10 +370,7 @@ class WorkflowEntry: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state - node_config: NodeConfigDict = { - "id": node_id, - "data": cast(NodeConfigData, node_data), - } + node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/dify_graph/entities/base_node_data.py b/api/dify_graph/entities/base_node_data.py new file mode 100644 index 0000000000..58869a94c2 --- /dev/null +++ b/api/dify_graph/entities/base_node_data.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import json +from abc import ABC +from builtins import type as type_ +from enum import StrEnum +from typing import Any, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from dify_graph.entities.exc import DefaultValueTypeError +from dify_graph.enums import ErrorStrategy, NodeType + +# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. +_NumberType = Union[int, float] + + +class RetryConfig(BaseModel): + """node retry config""" + + max_retries: int = 0 # max retry times + retry_interval: int = 0 # retry interval in milliseconds + retry_enabled: bool = False # whether retry is enabled + + @property + def retry_interval_seconds(self) -> float: + return self.retry_interval / 1000 + + +class DefaultValueType(StrEnum): + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY_NUMBER = "array[number]" + ARRAY_STRING = "array[string]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILES = "array[file]" + + +class DefaultValue(BaseModel): + value: Any = None + type: DefaultValueType + key: str + + @staticmethod + def _parse_json(value: str): + """Unified JSON parsing handler""" + try: + return json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") + + @staticmethod + def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: + """Unified array type validation""" + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) + + @staticmethod + def _convert_number(value: str) -> float: + """Unified number conversion handler""" + try: + return float(value) + except ValueError: + raise DefaultValueTypeError(f"Cannot convert to number: {value}") + + @model_validator(mode="after") + def validate_value_type(self) -> DefaultValue: + # Type validation configuration + type_validators: dict[DefaultValueType, dict[str, Any]] = { + DefaultValueType.STRING: { + "type": str, + "converter": lambda x: x, + }, + DefaultValueType.NUMBER: { + "type": _NumberType, + "converter": self._convert_number, + }, + DefaultValueType.OBJECT: { + "type": dict, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_NUMBER: { + "type": list, + "element_type": _NumberType, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_STRING: { + "type": list, + "element_type": str, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_OBJECT: { + "type": list, + "element_type": dict, + "converter": self._parse_json, + }, + } + + validator: dict[str, Any] = type_validators.get(self.type, {}) + if not validator: + if self.type == DefaultValueType.ARRAY_FILES: + # Handle files type + return self + raise DefaultValueTypeError(f"Unsupported type: {self.type}") + + # Handle string input cases + if isinstance(self.value, str) and self.type != DefaultValueType.STRING: + self.value = validator["converter"](self.value) + + # Validate base type + if not isinstance(self.value, validator["type"]): + raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") + + # Validate array element types + if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): + raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") + + return self + + +class BaseNodeData(ABC, BaseModel): + # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where + # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. + # At that boundary, node-specific fields are still "extra" relative to this shared DTO, + # and persisted templates/workflows also carry undeclared compatibility keys such as + # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive + # here until graph parsing becomes discriminated by node type or those legacy payloads + # are normalized. + model_config = ConfigDict(extra="allow") + + type: NodeType + title: str = "" + desc: str | None = None + version: str = "1" + error_strategy: ErrorStrategy | None = None + default_value: list[DefaultValue] | None = None + retry_config: RetryConfig = Field(default_factory=RetryConfig) + + @property + def default_value_dict(self) -> dict[str, Any]: + if self.default_value: + return {item.key: item.value for item in self.default_value} + return {} + + def __getitem__(self, key: str) -> Any: + """ + Dict-style access without calling model_dump() on every lookup. + Prefer using model fields and Pydantic's extra storage. + """ + # First, check declared model fields + if key in self.__class__.model_fields: + return getattr(self, key) + + # Then, check undeclared compatibility fields stored in Pydantic's extra dict. + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras[key] + + raise KeyError(key) + + def get(self, key: str, default: Any = None) -> Any: + """ + Dict-style .get() without calling model_dump() on every lookup. + """ + if key in self.__class__.model_fields: + return getattr(self, key) + + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras.get(key, default) + + return default diff --git a/api/dify_graph/nodes/base/exc.py b/api/dify_graph/entities/exc.py similarity index 100% rename from api/dify_graph/nodes/base/exc.py rename to api/dify_graph/entities/exc.py diff --git a/api/dify_graph/entities/graph_config.py b/api/dify_graph/entities/graph_config.py index 209dcfe6bc..36f7b94e82 100644 --- a/api/dify_graph/entities/graph_config.py +++ b/api/dify_graph/entities/graph_config.py @@ -4,21 +4,20 @@ import sys from pydantic import TypeAdapter, with_config +from dify_graph.entities.base_node_data import BaseNodeData + if sys.version_info >= (3, 12): from typing import TypedDict else: from typing_extensions import TypedDict -@with_config(extra="allow") -class NodeConfigData(TypedDict): - type: str - - @with_config(extra="allow") class NodeConfigDict(TypedDict): id: str - data: NodeConfigData + # This is the permissive raw graph boundary. Node factories re-validate `data` + # with the concrete `NodeData` subtype after resolving the node implementation. + data: BaseNodeData NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/dify_graph/graph/graph.py b/api/dify_graph/graph/graph.py index 3fe94eb3fd..3eb6bfc359 100644 --- a/api/dify_graph/graph/graph.py +++ b/api/dify_graph/graph/graph.py @@ -8,7 +8,7 @@ from typing import Protocol, cast, final from pydantic import TypeAdapter from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType +from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState from dify_graph.nodes.base.node import Node from libs.typing import is_str @@ -34,7 +34,8 @@ class NodeFactory(Protocol): :param node_config: node configuration dictionary containing type and other data :return: initialized Node instance - :raises ValueError: if node type is unknown or configuration is invalid + :raises ValueError: if node type is unknown or no implementation exists for the resolved version + :raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation """ ... @@ -115,10 +116,7 @@ class Graph: start_node_id = None for nid in root_candidates: node_data = node_configs_map[nid]["data"] - node_type = node_data["type"] - if not isinstance(node_type, str): - continue - if NodeType(node_type).is_start_node: + if node_data.type.is_start_node: start_node_id = nid break @@ -203,6 +201,23 @@ class Graph: return GraphBuilder(graph_cls=cls) + @staticmethod + def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]: + """ + Remove editor-only nodes before `NodeConfigDict` validation. + + Persisted note widgets use a top-level `type == "custom-note"` but leave + `data.type` empty because they are never executable graph nodes. Filter + them while configs are still raw dicts so Pydantic does not validate + their placeholder payloads against `BaseNodeData.type: NodeType`. + """ + filtered_node_configs: list[dict[str, object]] = [] + for node_config in node_configs: + if node_config.get("type", "") == "custom-note": + continue + filtered_node_configs.append(dict(node_config)) + return filtered_node_configs + @classmethod def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: """ @@ -302,13 +317,13 @@ class Graph: node_configs = graph_config.get("nodes", []) edge_configs = cast(list[dict[str, object]], edge_configs) + node_configs = cast(list[dict[str, object]], node_configs) + node_configs = cls._filter_canvas_only_nodes(node_configs) node_configs = _ListNodeConfigDict.validate_python(node_configs) if not node_configs: raise ValueError("Graph must have at least one node") - node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"] - # Parse node configurations node_configs_map = cls._parse_node_configs(node_configs) diff --git a/api/dify_graph/nodes/agent/agent_node.py b/api/dify_graph/nodes/agent/agent_node.py index d770f7afd1..d501217454 100644 --- a/api/dify_graph/nodes/agent/agent_node.py +++ b/api/dify_graph/nodes/agent/agent_node.py @@ -374,12 +374,11 @@ class AgentNode(Node[AgentNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: AgentNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = AgentNodeData.model_validate(node_data) - + _ = graph_config # Explicitly mark as unused result: dict[str, Any] = {} + typed_node_data = node_data for parameter_name in typed_node_data.agent_parameters: input = typed_node_data.agent_parameters[parameter_name] match input.type: diff --git a/api/dify_graph/nodes/agent/entities.py b/api/dify_graph/nodes/agent/entities.py index 9124420f01..f7b7af8fa4 100644 --- a/api/dify_graph/nodes/agent/entities.py +++ b/api/dify_graph/nodes/agent/entities.py @@ -5,10 +5,12 @@ from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class AgentNodeData(BaseNodeData): + type: NodeType = NodeType.AGENT agent_strategy_provider_name: str # redundancy agent_strategy_name: str agent_strategy_label: str # redundancy diff --git a/api/dify_graph/nodes/answer/answer_node.py b/api/dify_graph/nodes/answer/answer_node.py index d07b9c8062..c829b892cc 100644 --- a/api/dify_graph/nodes/answer/answer_node.py +++ b/api/dify_graph/nodes/answer/answer_node.py @@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: AnswerNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = AnswerNodeData.model_validate(node_data) - - variable_template_parser = VariableTemplateParser(template=typed_node_data.answer) + _ = graph_config # Explicitly mark as unused + variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} diff --git a/api/dify_graph/nodes/answer/entities.py b/api/dify_graph/nodes/answer/entities.py index 06927cd71e..3cc1d6572e 100644 --- a/api/dify_graph/nodes/answer/entities.py +++ b/api/dify_graph/nodes/answer/entities.py @@ -3,7 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class AnswerNodeData(BaseNodeData): @@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData): Answer Node Data. """ + type: NodeType = NodeType.ANSWER answer: str = Field(..., description="answer template string") diff --git a/api/dify_graph/nodes/base/__init__.py b/api/dify_graph/nodes/base/__init__.py index f83df0e323..036e25895d 100644 --- a/api/dify_graph/nodes/base/__init__.py +++ b/api/dify_graph/nodes/base/__init__.py @@ -1,4 +1,4 @@ -from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData +from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState from .usage_tracking_mixin import LLMUsageTrackingMixin __all__ = [ @@ -6,6 +6,5 @@ __all__ = [ "BaseIterationState", "BaseLoopNodeData", "BaseLoopState", - "BaseNodeData", "LLMUsageTrackingMixin", ] diff --git a/api/dify_graph/nodes/base/entities.py b/api/dify_graph/nodes/base/entities.py index 956fa59e78..4f8b2682e1 100644 --- a/api/dify_graph/nodes/base/entities.py +++ b/api/dify_graph/nodes/base/entities.py @@ -1,31 +1,12 @@ from __future__ import annotations -import json -from abc import ABC -from builtins import type as type_ from collections.abc import Sequence from enum import StrEnum -from typing import Any, Union +from typing import Any -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, field_validator -from dify_graph.enums import ErrorStrategy - -from .exc import DefaultValueTypeError - -_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds - retry_enabled: bool = False # whether retry is enabled - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 +from dify_graph.entities.base_node_data import BaseNodeData class VariableSelector(BaseModel): @@ -76,112 +57,6 @@ class OutputVariableEntity(BaseModel): return v -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = { - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: { - "type": list, - "element_type": _NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - title: str - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = RetryConfig() - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - class BaseIterationNodeData(BaseNodeData): start_node_id: str | None = None diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index 1f99a0a6e2..d7702c8cb4 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -12,6 +12,8 @@ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, ge from uuid import uuid4 from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams +from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import ( ErrorStrategy, @@ -62,8 +64,6 @@ from dify_graph.node_events import ( from dify_graph.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now -from .entities import BaseNodeData, RetryConfig - NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) _MISSING_RUN_CONTEXT_VALUE = object() @@ -153,11 +153,11 @@ class Node(Generic[NodeDataT]): Later, in __init__: :: - config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate() - │ - ▼ - CodeNodeData instance - (stored in self._node_data) + config["data"] ──► _node_data_type.model_validate(..., from_attributes=True) + │ + ▼ + CodeNodeData instance + (stored in self._node_data) Example: class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted @@ -241,7 +241,7 @@ class Node(Generic[NodeDataT]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, ) -> None: @@ -254,22 +254,21 @@ class Node(Generic[NodeDataT]): self.graph_runtime_state = graph_runtime_state self.state: NodeState = NodeState.UNKNOWN # node execution state - node_id = config.get("id") - if not node_id: - raise ValueError("Node ID is required.") + node_id = config["id"] self._node_id = node_id self._node_execution_id: str = "" self._start_at = naive_utc_now() - raw_node_data = config.get("data") or {} - if not isinstance(raw_node_data, Mapping): - raise ValueError("Node config data must be a mapping.") - - self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data) + self._node_data = self.validate_node_data(config["data"]) self.post_init() + @classmethod + def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT: + """Validate shared graph node payloads against the subclass-declared NodeData model.""" + return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) + def post_init(self) -> None: """Optional hook for subclasses requiring extra initialization.""" return @@ -342,9 +341,6 @@ class Node(Generic[NodeDataT]): return None return str(execution_id) - def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: - return cast(NodeDataT, self._node_data_type.model_validate(data)) - @abstractmethod def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: """ @@ -389,8 +385,6 @@ class Node(Generic[NodeDataT]): start_event.provider_id = getattr(self.node_data, "provider_id", "") start_event.provider_type = getattr(self.node_data, "provider_type", "") - from typing import cast - from dify_graph.nodes.agent.agent_node import AgentNode from dify_graph.nodes.agent.entities import AgentNodeData @@ -442,7 +436,7 @@ class Node(Generic[NodeDataT]): cls, *, graph_config: Mapping[str, Any], - config: Mapping[str, Any], + config: NodeConfigDict, ) -> Mapping[str, Sequence[str]]: """Extracts references variable selectors from node configuration. @@ -480,13 +474,12 @@ class Node(Generic[NodeDataT]): :param config: node config :return: """ - node_id = config.get("id") - if not node_id: - raise ValueError("Node ID is required when extracting variable selector to variable mapping.") - - # Pass raw dict data instead of creating NodeData instance + node_id = config["id"] + node_data = cls.validate_node_data(config["data"]) data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=config.get("data", {}) + graph_config=graph_config, + node_id=node_id, + node_data=node_data, ) return data @@ -496,7 +489,7 @@ class Node(Generic[NodeDataT]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: NodeDataT, ) -> Mapping[str, Sequence[str]]: return {} diff --git a/api/dify_graph/nodes/code/code_node.py b/api/dify_graph/nodes/code/code_node.py index 83e72deea9..ac8d6463b9 100644 --- a/api/dify_graph/nodes/code/code_node.py +++ b/api/dify_graph/nodes/code/code_node.py @@ -3,6 +3,7 @@ from decimal import Decimal from textwrap import dedent from typing import TYPE_CHECKING, Any, Protocol, cast +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: CodeNodeData, ) -> Mapping[str, Sequence[str]]: _ = graph_config # Explicitly mark as unused - # Create typed NodeData from dict - typed_node_data = CodeNodeData.model_validate(node_data) - return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in typed_node_data.variables + for variable_selector in node_data.variables } @property diff --git a/api/dify_graph/nodes/code/entities.py b/api/dify_graph/nodes/code/entities.py index 9e161c29d0..25e46226e1 100644 --- a/api/dify_graph/nodes/code/entities.py +++ b/api/dify_graph/nodes/code/entities.py @@ -3,7 +3,8 @@ from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.base.entities import VariableSelector from dify_graph.variables.types import SegmentType @@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData): Code Node Data. """ + type: NodeType = NodeType.CODE + class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] children: dict[str, "CodeNodeData.Output"] | None = None diff --git a/api/dify_graph/nodes/datasource/datasource_node.py b/api/dify_graph/nodes/datasource/datasource_node.py index b97394744e..b0d3e1a24e 100644 --- a/api/dify_graph/nodes/datasource/datasource_node.py +++ b/api/dify_graph/nodes/datasource/datasource_node.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey from dify_graph.node_events import NodeRunResult, StreamCompletedEvent @@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", datasource_manager: DatasourceManagerProtocol, @@ -181,7 +182,7 @@ class DatasourceNode(Node[DatasourceNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: DatasourceNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -190,11 +191,10 @@ class DatasourceNode(Node[DatasourceNodeData]): :param node_data: node data :return: """ - typed_node_data = DatasourceNodeData.model_validate(node_data) result = {} - if typed_node_data.datasource_parameters: - for parameter_name in typed_node_data.datasource_parameters: - input = typed_node_data.datasource_parameters[parameter_name] + if node_data.datasource_parameters: + for parameter_name in node_data.datasource_parameters: + input = node_data.datasource_parameters[parameter_name] match input.type: case "mixed": assert isinstance(input.value, str) diff --git a/api/dify_graph/nodes/datasource/entities.py b/api/dify_graph/nodes/datasource/entities.py index ba49e65f31..38275ac158 100644 --- a/api/dify_graph/nodes/datasource/entities.py +++ b/api/dify_graph/nodes/datasource/entities.py @@ -3,7 +3,8 @@ from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class DatasourceEntity(BaseModel): @@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel): class DatasourceNodeData(BaseNodeData, DatasourceEntity): + type: NodeType = NodeType.DATASOURCE + class DatasourceInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] diff --git a/api/dify_graph/nodes/document_extractor/entities.py b/api/dify_graph/nodes/document_extractor/entities.py index f4949d0df8..9f42d2e605 100644 --- a/api/dify_graph/nodes/document_extractor/entities.py +++ b/api/dify_graph/nodes/document_extractor/entities.py @@ -1,10 +1,12 @@ from collections.abc import Sequence from dataclasses import dataclass -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class DocumentExtractorNodeData(BaseNodeData): + type: NodeType = NodeType.DOCUMENT_EXTRACTOR variable_selector: Sequence[str] diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py index c26b18aac9..fe51b1963e 100644 --- a/api/dify_graph/nodes/document_extractor/node.py +++ b/api/dify_graph/nodes/document_extractor/node.py @@ -21,6 +21,7 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, file_manager from dify_graph.node_events import NodeRunResult @@ -54,7 +55,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -136,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: DocumentExtractorNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = DocumentExtractorNodeData.model_validate(node_data) - - return {node_id + ".files": typed_node_data.variable_selector} + _ = graph_config # Explicitly mark as unused + return {node_id + ".files": node_data.variable_selector} def _extract_text_by_mime_type( diff --git a/api/dify_graph/nodes/end/entities.py b/api/dify_graph/nodes/end/entities.py index a410087214..69cd1dd8f5 100644 --- a/api/dify_graph/nodes/end/entities.py +++ b/api/dify_graph/nodes/end/entities.py @@ -1,6 +1,8 @@ from pydantic import BaseModel, Field -from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.nodes.base.entities import OutputVariableEntity class EndNodeData(BaseNodeData): @@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData): END Node Data. """ + type: NodeType = NodeType.END outputs: list[OutputVariableEntity] diff --git a/api/dify_graph/nodes/http_request/entities.py b/api/dify_graph/nodes/http_request/entities.py index a5564689f8..46e08ea1a0 100644 --- a/api/dify_graph/nodes/http_request/entities.py +++ b/api/dify_graph/nodes/http_request/entities.py @@ -8,7 +8,8 @@ import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" @@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. """ + type: NodeType = NodeType.HTTP_REQUEST method: Literal[ "get", "post", diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index 2e48d5502a..3895ae92c0 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -3,6 +3,7 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod from dify_graph.node_events import NodeRunResult @@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -163,18 +164,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: HttpRequestNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = HttpRequestNodeData.model_validate(node_data) - selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params) - if typed_node_data.body: - body_type = typed_node_data.body.type - data = typed_node_data.body.data + selectors += variable_template_parser.extract_selectors_from_template(node_data.url) + selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(node_data.params) + if node_data.body: + body_type = node_data.body.type + data = node_data.body.data match body_type: case "none": pass diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py index 5616949dcc..82a8f32053 100644 --- a/api/dify_graph/nodes/human_input/entities.py +++ b/api/dify_graph/nodes/human_input/entities.py @@ -10,7 +10,8 @@ from typing import Annotated, Any, ClassVar, Literal, Self from pydantic import BaseModel, Field, field_validator, model_validator -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser from dify_graph.runtime import VariablePool from dify_graph.variables.consts import SELECTORS_LENGTH @@ -214,6 +215,7 @@ class UserAction(BaseModel): class HumanInputNodeData(BaseNodeData): """Human Input node data.""" + type: NodeType = NodeType.HUMAN_INPUT delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) form_content: str = "" inputs: list[FormInput] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index 03c2d17b1d..3a167d122b 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -3,6 +3,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import ( @@ -63,7 +64,7 @@ class HumanInputNode(Node[HumanInputNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", form_repository: HumanInputFormRepository, @@ -348,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: HumanInputNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selectors referenced in form content and input default values. @@ -357,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]): 1. Variables referenced in form_content ({{#node_name.var_name#}}) 2. Variables referenced in input default values """ - validated_node_data = HumanInputNodeData.model_validate(node_data) - return validated_node_data.extract_variable_selector_to_variable_mapping(node_id) + return node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/dify_graph/nodes/if_else/entities.py b/api/dify_graph/nodes/if_else/entities.py index 4733944039..c9bb1cdc7f 100644 --- a/api/dify_graph/nodes/if_else/entities.py +++ b/api/dify_graph/nodes/if_else/entities.py @@ -2,7 +2,8 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.utils.condition.entities import Condition @@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData): If Else Node Data. """ + type: NodeType = NodeType.IF_ELSE + class Case(BaseModel): """ Case entity representing a single logical condition group diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py index 3c5a33e2b7..4b6d30c279 100644 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ b/api/dify_graph/nodes/if_else/if_else_node.py @@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: IfElseNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = IfElseNodeData.model_validate(node_data) - var_mapping: dict[str, list[str]] = {} - for case in typed_node_data.cases or []: + _ = graph_config # Explicitly mark as unused + for case in node_data.cases or []: for condition in case.conditions: key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" var_mapping[key] = condition.variable_selector diff --git a/api/dify_graph/nodes/iteration/entities.py b/api/dify_graph/nodes/iteration/entities.py index a31b05463e..6d61c12352 100644 --- a/api/dify_graph/nodes/iteration/entities.py +++ b/api/dify_graph/nodes/iteration/entities.py @@ -3,7 +3,9 @@ from typing import Any from pydantic import Field -from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState class ErrorHandleMode(StrEnum): @@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData): Iteration Node Data. """ + type: NodeType = NodeType.ITERATION parent_loop_id: str | None = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector @@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData): Iteration Start Node Data. """ - pass + type: NodeType = NodeType.ITERATION_START class IterationState(BaseIterationState): diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index 6d26cbfce4..bc36f635e9 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( NodeExecutionType, NodeType, @@ -460,21 +461,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: IterationNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = IterationNodeData.model_validate(node_data) - variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": typed_node_data.iterator_selector, + f"{node_id}.input_selector": node_data.iterator_selector, } iteration_node_ids = set() # Find all nodes that belong to this loop nodes = graph_config.get("nodes", []) for node in nodes: - node_data = node.get("data", {}) - if node_data.get("iteration_id") == node_id: + node_config_data = node.get("data", {}) + if node_config_data.get("iteration_id") == node_id: in_iteration_node_id = node.get("id") if in_iteration_node_id: iteration_node_ids.add(in_iteration_node_id) @@ -490,14 +488,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # Get node class from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING - node_type = NodeType(sub_node_config.get("data", {}).get("type")) + typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) + node_type = typed_sub_node_config["data"].type if node_type not in NODE_TYPE_CLASSES_MAPPING: continue - node_version = sub_node_config.get("data", {}).get("version", "1") + node_version = str(typed_sub_node_config["data"].version) node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=sub_node_config + graph_config=graph_config, config=typed_sub_node_config ) sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) except NotImplementedError: diff --git a/api/dify_graph/nodes/knowledge_index/entities.py b/api/dify_graph/nodes/knowledge_index/entities.py index 493b5eadd8..d88ee8e3af 100644 --- a/api/dify_graph/nodes/knowledge_index/entities.py +++ b/api/dify_graph/nodes/knowledge_index/entities.py @@ -3,7 +3,8 @@ from typing import Literal, Union from pydantic import BaseModel from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class RerankingModelConfig(BaseModel): @@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData): Knowledge index Node Data. """ - type: str = "knowledge-index" + type: NodeType = NodeType.KNOWLEDGE_INDEX chunk_structure: str index_chunk_variable_selector: list[str] indexing_technique: str | None = None diff --git a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py index eeb4f3c229..3c4fe2344c 100644 --- a/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py +++ b/api/dify_graph/nodes/knowledge_index/knowledge_index_node.py @@ -2,6 +2,7 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import NodeExecutionType, NodeType, SystemVariableKey from dify_graph.node_events import NodeRunResult @@ -30,7 +31,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", index_processor: IndexProcessorProtocol, diff --git a/api/dify_graph/nodes/knowledge_retrieval/entities.py b/api/dify_graph/nodes/knowledge_retrieval/entities.py index c3059897c7..8f226b9785 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/entities.py +++ b/api/dify_graph/nodes/knowledge_retrieval/entities.py @@ -3,7 +3,8 @@ from typing import Literal from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig @@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): Knowledge retrieval Node Data. """ - type: str = "knowledge-retrieval" + type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL query_variable_selector: list[str] | None | str = None query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] diff --git a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py index c67e14ce17..61c9614340 100644 --- a/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeType, WorkflowNodeExecutionMetadataKey, @@ -49,7 +50,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", rag_retrieval: RAGRetrievalProtocol, @@ -301,15 +302,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: KnowledgeRetrievalNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type - # Create typed NodeData from dict - typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) - variable_mapping = {} - if typed_node_data.query_variable_selector: - variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector - if typed_node_data.query_attachment_selector: - variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector + if node_data.query_variable_selector: + variable_mapping[node_id + ".query"] = node_data.query_variable_selector + if node_data.query_attachment_selector: + variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector return variable_mapping diff --git a/api/dify_graph/nodes/list_operator/entities.py b/api/dify_graph/nodes/list_operator/entities.py index 0fdd85f210..a91cfab8de 100644 --- a/api/dify_graph/nodes/list_operator/entities.py +++ b/api/dify_graph/nodes/list_operator/entities.py @@ -3,7 +3,8 @@ from enum import StrEnum from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class FilterOperator(StrEnum): @@ -62,6 +63,7 @@ class ExtractConfig(BaseModel): class ListOperatorNodeData(BaseNodeData): + type: NodeType = NodeType.LIST_OPERATOR variable: Sequence[str] = Field(default_factory=list) filter_by: FilterBy order_by: OrderByConfig diff --git a/api/dify_graph/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py index 707ed8ece0..71728aa227 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -4,8 +4,9 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode -from dify_graph.nodes.base import BaseNodeData from dify_graph.nodes.base.entities import VariableSelector @@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): + type: NodeType = NodeType.LLM model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 5e59c96cd6..b88ff404c0 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -21,6 +21,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.tools.signature import sign_upload_file from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeType, SystemVariableKey, @@ -121,7 +122,7 @@ class LLMNode(Node[LLMNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, @@ -954,14 +955,11 @@ class LLMNode(Node[LLMNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: LLMNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type _ = graph_config # Explicitly mark as unused - # Create typed NodeData from dict - typed_node_data = LLMNodeData.model_validate(node_data) - - prompt_template = typed_node_data.prompt_template + prompt_template = node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: @@ -979,7 +977,7 @@ class LLMNode(Node[LLMNodeData]): for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - memory = typed_node_data.memory + memory = node_data.memory if memory and memory.query_prompt_template: query_variable_selectors = VariableTemplateParser( template=memory.query_prompt_template @@ -987,16 +985,16 @@ class LLMNode(Node[LLMNodeData]): for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - if typed_node_data.context.enabled: - variable_mapping["#context#"] = typed_node_data.context.variable_selector + if node_data.context.enabled: + variable_mapping["#context#"] = node_data.context.variable_selector - if typed_node_data.vision.enabled: - variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector + if node_data.vision.enabled: + variable_mapping["#files#"] = node_data.vision.configs.variable_selector - if typed_node_data.memory: + if node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - if typed_node_data.prompt_config: + if node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): @@ -1009,7 +1007,7 @@ class LLMNode(Node[LLMNodeData]): break if enable_jinja: - for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: + for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} diff --git a/api/dify_graph/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py index b4a8518048..8a3df5c234 100644 --- a/api/dify_graph/nodes/loop/entities.py +++ b/api/dify_graph/nodes/loop/entities.py @@ -3,7 +3,9 @@ from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator -from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState from dify_graph.utils.condition.entities import Condition from dify_graph.variables.types import SegmentType @@ -39,6 +41,7 @@ class LoopVariableData(BaseModel): class LoopNodeData(BaseLoopNodeData): + type: NodeType = NodeType.LOOP loop_count: int # Maximum number of loops break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] @@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData): Loop Start Node Data. """ - pass + type: NodeType = NodeType.LOOP_START class LoopEndNodeData(BaseNodeData): @@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData): Loop End Node Data. """ - pass + type: NodeType = NodeType.LOOP_END class LoopState(BaseLoopState): diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 8279f0fc66..27b78eaa3f 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -5,6 +5,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, cast +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( NodeExecutionType, NodeType, @@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: LoopNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = LoopNodeData.model_validate(node_data) - variable_mapping = {} # Extract loop node IDs statically from graph_config @@ -320,14 +318,15 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # Get node class from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING - node_type = NodeType(sub_node_config.get("data", {}).get("type")) + typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) + node_type = typed_sub_node_config["data"].type if node_type not in NODE_TYPE_CLASSES_MAPPING: continue - node_version = sub_node_config.get("data", {}).get("version", "1") + node_version = str(typed_sub_node_config["data"].version) node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=sub_node_config + graph_config=graph_config, config=typed_sub_node_config ) sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) except NotImplementedError: @@ -342,7 +341,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): variable_mapping.update(sub_node_variable_mapping) - for loop_variable in typed_node_data.loop_variables or []: + for loop_variable in node_data.loop_variables or []: if loop_variable.value_type == "variable": assert loop_variable.value is not None, "Loop variable value must be provided for variable type" # add loop variable to variable mapping diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py index 3b042710f9..8f8a278d5b 100644 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ b/api/dify_graph/nodes/parameter_extractor/entities.py @@ -8,7 +8,8 @@ from pydantic import ( ) from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig from dify_graph.variables.types import SegmentType @@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData): Parameter Extractor Node Data. """ + type: NodeType = NodeType.PARAMETER_EXTRACTOR model: ModelConfig query: list[str] parameters: list[ParameterConfig] diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 1325a6a09a..68bd15db30 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -10,6 +10,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeType, WorkflowNodeExecutionMetadataKey, @@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -837,15 +838,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: ParameterExtractorNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = ParameterExtractorNodeData.model_validate(node_data) + _ = graph_config # Explicitly mark as unused + variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} - variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query} - - if typed_node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction) + if node_data.instruction: + selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) for selector in selectors: variable_mapping[selector.variable] = selector.value_selector diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py index 03e0a0ac53..77a6c70c28 100644 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ b/api/dify_graph/nodes/question_classifier/entities.py @@ -1,7 +1,8 @@ from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.llm import ModelConfig, VisionConfig @@ -11,6 +12,7 @@ class ClassConfig(BaseModel): class QuestionClassifierNodeData(BaseNodeData): + type: NodeType = NodeType.QUESTION_CLASSIFIER query_variable_selector: list[str] model: ModelConfig classes: list[ClassConfig] diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 443d216186..a61bca4ea9 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -7,6 +7,7 @@ from core.model_manager import ModelInstance from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeExecutionType, NodeType, @@ -62,7 +63,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -251,16 +252,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: QuestionClassifierNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type - # Create typed NodeData from dict - typed_node_data = QuestionClassifierNodeData.model_validate(node_data) - - variable_mapping = {"query": typed_node_data.query_variable_selector} + variable_mapping = {"query": node_data.query_variable_selector} variable_selectors: list[VariableSelector] = [] - if typed_node_data.instruction: - variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction) + if node_data.instruction: + variable_template_parser = VariableTemplateParser(template=node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) diff --git a/api/dify_graph/nodes/start/entities.py b/api/dify_graph/nodes/start/entities.py index 0df832740e..cbf7348360 100644 --- a/api/dify_graph/nodes/start/entities.py +++ b/api/dify_graph/nodes/start/entities.py @@ -2,7 +2,8 @@ from collections.abc import Sequence from pydantic import Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.variables.input_entities import VariableEntity @@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData): Start Node Data """ + type: NodeType = NodeType.START variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/template_transform/entities.py b/api/dify_graph/nodes/template_transform/entities.py index 123fd41f81..2a79a82870 100644 --- a/api/dify_graph/nodes/template_transform/entities.py +++ b/api/dify_graph/nodes/template_transform/entities.py @@ -1,4 +1,5 @@ -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.base.entities import VariableSelector @@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData): Template Transform Node Data. """ + type: NodeType = NodeType.TEMPLATE_TRANSFORM variables: list[VariableSelector] template: str diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py index 367442e997..9dfb535342 100644 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ b/api/dify_graph/nodes/template_transform/template_transform_node.py @@ -1,6 +1,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -25,7 +26,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -86,12 +87,9 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any] + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = TemplateTransformNodeData.model_validate(node_data) - return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in typed_node_data.variables + for variable_selector in node_data.variables } diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py index f15dabdeeb..4ba8c16e85 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -4,7 +4,8 @@ from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class ToolEntity(BaseModel): @@ -32,6 +33,8 @@ class ToolEntity(BaseModel): class ToolNodeData(BaseNodeData, ToolEntity): + type: NodeType = NodeType.TOOL + class ToolInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index a6e0b710f1..44d0ca885d 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -7,6 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( NodeType, SystemVariableKey, @@ -46,7 +47,7 @@ class ToolNode(Node[ToolNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -484,7 +485,7 @@ class ToolNode(Node[ToolNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: ToolNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -493,9 +494,8 @@ class ToolNode(Node[ToolNodeData]): :param node_data: node data :return: """ - # Create typed NodeData from dict - typed_node_data = ToolNodeData.model_validate(node_data) - + _ = graph_config # Explicitly mark as unused + typed_node_data = node_data result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] diff --git a/api/dify_graph/nodes/trigger_plugin/entities.py b/api/dify_graph/nodes/trigger_plugin/entities.py index 75d10ecaa4..33a61c9bc8 100644 --- a/api/dify_graph/nodes/trigger_plugin/entities.py +++ b/api/dify_graph/nodes/trigger_plugin/entities.py @@ -4,13 +4,16 @@ from typing import Any, Literal, Union from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.entities.entities import EventParameter -from dify_graph.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.nodes.trigger_plugin.exc import TriggerEventParameterError class TriggerEventNodeData(BaseNodeData): """Plugin trigger node data""" + type: NodeType = NodeType.TRIGGER_PLUGIN + class TriggerEventInput(BaseModel): value: Union[Any, list[str]] type: Literal["mixed", "variable", "constant"] @@ -38,8 +41,6 @@ class TriggerEventNodeData(BaseNodeData): raise ValueError("value must be a string, int, float, bool or dict") return type - title: str - desc: str | None = None plugin_id: str = Field(..., description="Plugin ID") provider_id: str = Field(..., description="Provider ID") event_name: str = Field(..., description="Event name") diff --git a/api/dify_graph/nodes/trigger_schedule/entities.py b/api/dify_graph/nodes/trigger_schedule/entities.py index 6daadc7666..2b0edcabba 100644 --- a/api/dify_graph/nodes/trigger_schedule/entities.py +++ b/api/dify_graph/nodes/trigger_schedule/entities.py @@ -2,7 +2,8 @@ from typing import Literal, Union from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): @@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData): Trigger Schedule Node Data """ + type: NodeType = NodeType.TRIGGER_SCHEDULE mode: str = Field(default="visual", description="Schedule mode: visual or cron") frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly") cron_expression: str | None = Field(default=None, description="Cron expression for cron mode") diff --git a/api/dify_graph/nodes/trigger_schedule/exc.py b/api/dify_graph/nodes/trigger_schedule/exc.py index caea6241e4..336d64d58f 100644 --- a/api/dify_graph/nodes/trigger_schedule/exc.py +++ b/api/dify_graph/nodes/trigger_schedule/exc.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.base.exc import BaseNodeError +from dify_graph.entities.exc import BaseNodeError class ScheduleNodeError(BaseNodeError): diff --git a/api/dify_graph/nodes/trigger_webhook/entities.py b/api/dify_graph/nodes/trigger_webhook/entities.py index fa36aeabd3..a4f8745e71 100644 --- a/api/dify_graph/nodes/trigger_webhook/entities.py +++ b/api/dify_graph/nodes/trigger_webhook/entities.py @@ -1,10 +1,41 @@ from collections.abc import Sequence from enum import StrEnum -from typing import Literal from pydantic import BaseModel, Field, field_validator -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.variables.types import SegmentType + +_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + } +) + +_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.BOOLEAN, + } +) + +_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES + +_WEBHOOK_BODY_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.BOOLEAN, + SegmentType.OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_BOOLEAN, + SegmentType.ARRAY_OBJECT, + SegmentType.FILE, + } +) class Method(StrEnum): @@ -25,29 +56,34 @@ class ContentType(StrEnum): class WebhookParameter(BaseModel): - """Parameter definition for headers, query params, or body.""" + """Parameter definition for headers or query params.""" name: str + type: SegmentType = SegmentType.STRING required: bool = False + @field_validator("type", mode="after") + @classmethod + def validate_type(cls, v: SegmentType) -> SegmentType: + if v not in _WEBHOOK_PARAMETER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook parameter type: {v}") + return v + class WebhookBodyParameter(BaseModel): """Body parameter with type information.""" name: str - type: Literal[ - "string", - "number", - "boolean", - "object", - "array[string]", - "array[number]", - "array[boolean]", - "array[object]", - "file", - ] = "string" + type: SegmentType = SegmentType.STRING required: bool = False + @field_validator("type", mode="after") + @classmethod + def validate_type(cls, v: SegmentType) -> SegmentType: + if v not in _WEBHOOK_BODY_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook body parameter type: {v}") + return v + class WebhookData(BaseNodeData): """ @@ -57,6 +93,7 @@ class WebhookData(BaseNodeData): class SyncMode(StrEnum): SYNC = "async" # only support + type: NodeType = NodeType.TRIGGER_WEBHOOK method: Method = Method.GET content_type: ContentType = Field(default=ContentType.JSON) headers: Sequence[WebhookParameter] = Field(default_factory=list) @@ -71,6 +108,22 @@ class WebhookData(BaseNodeData): return v.lower() return v + @field_validator("headers", mode="after") + @classmethod + def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]: + for param in v: + if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook header parameter type: {param.type}") + return v + + @field_validator("params", mode="after") + @classmethod + def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]: + for param in v: + if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook query parameter type: {param.type}") + return v + status_code: int = 200 # Expected status code for response response_body: str = "" # Template for response body diff --git a/api/dify_graph/nodes/trigger_webhook/exc.py b/api/dify_graph/nodes/trigger_webhook/exc.py index 853b2456c5..4d87f2a069 100644 --- a/api/dify_graph/nodes/trigger_webhook/exc.py +++ b/api/dify_graph/nodes/trigger_webhook/exc.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.base.exc import BaseNodeError +from dify_graph.entities.exc import BaseNodeError class WebhookNodeError(BaseNodeError): diff --git a/api/dify_graph/nodes/trigger_webhook/node.py b/api/dify_graph/nodes/trigger_webhook/node.py index e466541908..413eda5272 100644 --- a/api/dify_graph/nodes/trigger_webhook/node.py +++ b/api/dify_graph/nodes/trigger_webhook/node.py @@ -152,7 +152,7 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = raw_data continue - if param_type == "file": + if param_type == SegmentType.FILE: # Get File object (already processed by webhook controller) files = webhook_data.get("files", {}) if files and isinstance(files, dict): diff --git a/api/dify_graph/nodes/variable_aggregator/entities.py b/api/dify_graph/nodes/variable_aggregator/entities.py index 5f7c1dbe93..fec4c4474c 100644 --- a/api/dify_graph/nodes/variable_aggregator/entities.py +++ b/api/dify_graph/nodes/variable_aggregator/entities.py @@ -1,6 +1,7 @@ from pydantic import BaseModel -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from dify_graph.variables.types import SegmentType @@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData): Variable Aggregator Node Data. """ + type: NodeType = NodeType.VARIABLE_AGGREGATOR output_type: str variables: list[list[str]] advanced_settings: AdvancedSettings | None = None diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py index 1aa7042b02..1d17b981ba 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -22,7 +23,7 @@ class VariableAssignerNode(Node[VariableAssignerData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ): @@ -52,21 +53,18 @@ class VariableAssignerNode(Node[VariableAssignerData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: VariableAssignerData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = VariableAssignerData.model_validate(node_data) - mapping = {} - assigned_variable_node_id = typed_node_data.assigned_variable_selector[0] + assigned_variable_node_id = node_data.assigned_variable_selector[0] if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(typed_node_data.assigned_variable_selector) + selector_key = ".".join(node_data.assigned_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = typed_node_data.assigned_variable_selector + mapping[key] = node_data.assigned_variable_selector - selector_key = ".".join(typed_node_data.input_variable_selector) + selector_key = ".".join(node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = typed_node_data.input_variable_selector + mapping[key] = node_data.input_variable_selector return mapping def _run(self) -> NodeRunResult: diff --git a/api/dify_graph/nodes/variable_assigner/v1/node_data.py b/api/dify_graph/nodes/variable_assigner/v1/node_data.py index 11e8f93f35..a75a2397ba 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node_data.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node_data.py @@ -1,7 +1,8 @@ from collections.abc import Sequence from enum import StrEnum -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class WriteMode(StrEnum): @@ -11,6 +12,7 @@ class WriteMode(StrEnum): class VariableAssignerData(BaseNodeData): + type: NodeType = NodeType.VARIABLE_ASSIGNER assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/dify_graph/nodes/variable_assigner/v2/entities.py b/api/dify_graph/nodes/variable_assigner/v2/entities.py index 5f9211d600..ca3a94b777 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/entities.py +++ b/api/dify_graph/nodes/variable_assigner/v2/entities.py @@ -3,7 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from dify_graph.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType from .enums import InputType, Operation @@ -22,5 +23,6 @@ class VariableOperationItem(BaseModel): class VariableAssignerNodeData(BaseNodeData): + type: NodeType = NodeType.VARIABLE_ASSIGNER version: str = "2" items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py index 7753382cd0..771609ceb6 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ b/api/dify_graph/nodes/variable_assigner/v2/node.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -56,7 +57,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ): @@ -94,13 +95,10 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: VariableAssignerNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = VariableAssignerNodeData.model_validate(node_data) - var_mapping: dict[str, Sequence[str]] = {} - for item in typed_node_data.items: + for item in node_data.items: _target_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item) return var_mapping diff --git a/api/models/workflow.py b/api/models/workflow.py index d728ed83bc..21b899eeda 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -233,8 +233,11 @@ class Workflow(Base): # bug def get_node_config_by_id(self, node_id: str) -> NodeConfigDict: """Extract a node configuration from the workflow graph by node ID. - A node configuration is a dictionary containing the node's properties, including - the node's id, title, and its data as a dict. + + A node configuration includes the node id and a typed `BaseNodeData` for `data`. + `BaseNodeData` keeps a dict-like `get`/`__getitem__` compatibility layer backed by + model fields plus Pydantic extra storage for legacy consumers, but callers should + prefer attribute access. """ workflow_graph = self.graph_dict @@ -252,12 +255,9 @@ class Workflow(Base): # bug return NodeConfigDictAdapter.validate_python(node_config) @staticmethod - def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: + def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: """Extract type of a node from the node configuration returned by `get_node_config_by_id`.""" - node_config_data = node_config.get("data", {}) - # Get node class - node_type = NodeType(node_config_data.get("type")) - return node_type + return node_config["data"].type @staticmethod def get_enclosing_node_type_and_id( diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 8389ccbb34..88b640305d 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -1,14 +1,18 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.nodes import NodeType -from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig +from dify_graph.nodes.trigger_schedule.entities import ( + ScheduleConfig, + SchedulePlanUpdate, + TriggerScheduleNodeData, + VisualConfig, +) from dify_graph.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin @@ -176,26 +180,26 @@ class ScheduleService: return next_run_at @staticmethod - def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig: + def to_schedule_config(node_config: NodeConfigDict) -> ScheduleConfig: """ Converts user-friendly visual schedule settings to cron expression. Maintains consistency with frontend UI expectations while supporting croniter's extended syntax. """ - node_data = node_config.get("data", {}) - mode = node_data.get("mode", "visual") - timezone = node_data.get("timezone", "UTC") - node_id = node_config.get("id", "start") + node_data = TriggerScheduleNodeData.model_validate(node_config["data"], from_attributes=True) + mode = node_data.mode + timezone = node_data.timezone + node_id = node_config["id"] cron_expression = None if mode == "cron": - cron_expression = node_data.get("cron_expression") + cron_expression = node_data.cron_expression if not cron_expression: raise ScheduleConfigError("Cron expression is required for cron mode") elif mode == "visual": - frequency = str(node_data.get("frequency")) + frequency = str(node_data.frequency or "") if not frequency: raise ScheduleConfigError("Frequency is required for visual mode") - visual_config = VisualConfig(**node_data.get("visual_config", {})) + visual_config = VisualConfig.model_validate(node_data.visual_config or {}) cron_expression = ScheduleService.visual_to_cron(frequency=frequency, visual_config=visual_config) if not cron_expression: raise ScheduleConfigError("Cron expression is required for visual mode") @@ -239,19 +243,21 @@ class ScheduleService: if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value: continue - mode = node_data.get("mode", "visual") - timezone = node_data.get("timezone", "UTC") node_id = node.get("id", "start") + trigger_data = TriggerScheduleNodeData.model_validate(node_data) + mode = trigger_data.mode + timezone = trigger_data.timezone cron_expression = None if mode == "cron": - cron_expression = node_data.get("cron_expression") + cron_expression = trigger_data.cron_expression if not cron_expression: raise ScheduleConfigError("Cron expression is required for cron mode") elif mode == "visual": - frequency = node_data.get("frequency") - visual_config_dict = node_data.get("visual_config", {}) - visual_config = VisualConfig(**visual_config_dict) + frequency = trigger_data.frequency + if not frequency: + raise ScheduleConfigError("Frequency is required for visual mode") + visual_config = VisualConfig.model_validate(trigger_data.visual_config or {}) cron_expression = ScheduleService.visual_to_cron(frequency, visual_config) else: raise ScheduleConfigError(f"Invalid schedule mode: {mode}") diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index f1f0d0ea84..2343bbbd3d 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -16,6 +16,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db @@ -41,7 +42,7 @@ class TriggerService: @classmethod def invoke_trigger_event( - cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent + cls, tenant_id: str, user_id: str, node_config: NodeConfigDict, event: PluginTriggerDebugEvent ) -> TriggerInvokeEventResponse: """Invoke a trigger event.""" subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id( @@ -50,7 +51,7 @@ class TriggerService: ) if not subscription: raise ValueError("Subscription not found") - node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {})) + node_data = TriggerEventNodeData.model_validate(node_config["data"], from_attributes=True) request = TriggerHttpRequestCachingService.get_request(event.request_id) payload = TriggerHttpRequestCachingService.get_payload(event.request_id) # invoke triger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 285645edce..02977b934c 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -2,7 +2,7 @@ import json import logging import mimetypes import secrets -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence from typing import Any import orjson @@ -16,9 +16,16 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.tool_file_manager import ToolFileManager +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import NodeType from dify_graph.file.models import FileTransferMethod -from dify_graph.variables.types import SegmentType +from dify_graph.nodes.trigger_webhook.entities import ( + ContentType, + WebhookBodyParameter, + WebhookData, + WebhookParameter, +) +from dify_graph.variables.types import ArrayValidation, SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -57,7 +64,7 @@ class WebhookService: @classmethod def get_webhook_trigger_and_workflow( cls, webhook_id: str, is_debug: bool = False - ) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]: + ) -> tuple[WorkflowWebhookTrigger, Workflow, NodeConfigDict]: """Get webhook trigger, workflow, and node configuration. Args: @@ -135,7 +142,7 @@ class WebhookService: @classmethod def extract_and_validate_webhook_data( - cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any] + cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict ) -> dict[str, Any]: """Extract and validate webhook data in a single unified process. @@ -153,7 +160,7 @@ class WebhookService: raw_data = cls.extract_webhook_data(webhook_trigger) # Validate HTTP metadata (method, content-type) - node_data = node_config.get("data", {}) + node_data = WebhookData.model_validate(node_config["data"], from_attributes=True) validation_result = cls._validate_http_metadata(raw_data, node_data) if not validation_result["valid"]: raise ValueError(validation_result["error"]) @@ -192,7 +199,7 @@ class WebhookService: content_type = cls._extract_content_type(dict(request.headers)) # Route to appropriate extractor based on content type - extractors = { + extractors: dict[str, Callable[[], tuple[dict[str, Any], dict[str, Any]]]] = { "application/json": cls._extract_json_body, "application/x-www-form-urlencoded": cls._extract_form_body, "multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger), @@ -214,7 +221,7 @@ class WebhookService: return data @classmethod - def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]: """Process and validate webhook data according to node configuration. Args: @@ -230,18 +237,13 @@ class WebhookService: result = raw_data.copy() # Validate and process headers - cls._validate_required_headers(raw_data["headers"], node_data.get("headers", [])) + cls._validate_required_headers(raw_data["headers"], node_data.headers) # Process query parameters with type conversion and validation - result["query_params"] = cls._process_parameters( - raw_data["query_params"], node_data.get("params", []), is_form_data=True - ) + result["query_params"] = cls._process_parameters(raw_data["query_params"], node_data.params, is_form_data=True) # Process body parameters based on content type - configured_content_type = node_data.get("content_type", "application/json").lower() - result["body"] = cls._process_body_parameters( - raw_data["body"], node_data.get("body", []), configured_content_type - ) + result["body"] = cls._process_body_parameters(raw_data["body"], node_data.body, node_data.content_type) return result @@ -424,7 +426,11 @@ class WebhookService: @classmethod def _process_parameters( - cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False + cls, + raw_params: dict[str, str], + param_configs: Sequence[WebhookParameter], + *, + is_form_data: bool = False, ) -> dict[str, Any]: """Process parameters with unified validation and type conversion. @@ -440,13 +446,13 @@ class WebhookService: ValueError: If required parameters are missing or validation fails """ processed = {} - configured_params = {config.get("name", ""): config for config in param_configs} + configured_params = {config.name: config for config in param_configs} # Process configured parameters for param_config in param_configs: - name = param_config.get("name", "") - param_type = param_config.get("type", SegmentType.STRING) - required = param_config.get("required", False) + name = param_config.name + param_type = param_config.type + required = param_config.required # Check required parameters if required and name not in raw_params: @@ -465,7 +471,10 @@ class WebhookService: @classmethod def _process_body_parameters( - cls, raw_body: dict[str, Any], body_configs: list, content_type: str + cls, + raw_body: dict[str, Any], + body_configs: Sequence[WebhookBodyParameter], + content_type: ContentType, ) -> dict[str, Any]: """Process body parameters based on content type and configuration. @@ -480,25 +489,28 @@ class WebhookService: Raises: ValueError: If required body parameters are missing or validation fails """ - if content_type in ["text/plain", "application/octet-stream"]: - # For text/plain and octet-stream, validate required content exists - if body_configs and any(config.get("required", False) for config in body_configs): - raw_content = raw_body.get("raw") - if not raw_content: - raise ValueError(f"Required body content missing for {content_type} request") - return raw_body + match content_type: + case ContentType.TEXT | ContentType.BINARY: + # For text/plain and octet-stream, validate required content exists + if body_configs and any(config.required for config in body_configs): + raw_content = raw_body.get("raw") + if not raw_content: + raise ValueError(f"Required body content missing for {content_type} request") + return raw_body + case _: + pass # For structured data (JSON, form-data, etc.) processed = {} - configured_params = {config.get("name", ""): config for config in body_configs} + configured_params: dict[str, WebhookBodyParameter] = {config.name: config for config in body_configs} for body_config in body_configs: - name = body_config.get("name", "") - param_type = body_config.get("type", SegmentType.STRING) - required = body_config.get("required", False) + name = body_config.name + param_type = body_config.type + required = body_config.required # Handle file parameters for multipart data - if param_type == SegmentType.FILE and content_type == "multipart/form-data": + if param_type == SegmentType.FILE and content_type == ContentType.FORM_DATA: # File validation is handled separately in extract phase continue @@ -508,7 +520,7 @@ class WebhookService: if name in raw_body: raw_value = raw_body[name] - is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"] + is_form_data = content_type in [ContentType.FORM_URLENCODED, ContentType.FORM_DATA] processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data) # Include unconfigured parameters @@ -519,7 +531,9 @@ class WebhookService: return processed @classmethod - def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any: + def _validate_and_convert_value( + cls, param_name: str, value: Any, param_type: SegmentType | str, is_form_data: bool + ) -> Any: """Unified validation and type conversion for parameter values. Args: @@ -532,7 +546,8 @@ class WebhookService: Any: The validated and converted value Raises: - ValueError: If validation or conversion fails + ValueError: If validation or conversion fails. The original validation + error is preserved as ``__cause__`` for debugging. """ try: if is_form_data: @@ -542,10 +557,10 @@ class WebhookService: # JSON data should already be in correct types, just validate return cls._validate_json_value(param_name, value, param_type) except Exception as e: - raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}") + raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}") from e @classmethod - def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any: + def _convert_form_value(cls, param_name: str, value: str, param_type: SegmentType | str) -> Any: """Convert form data string values to specified types. Args: @@ -576,7 +591,7 @@ class WebhookService: raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") @classmethod - def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any: + def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any: """Validate JSON values against expected types. Args: @@ -590,43 +605,43 @@ class WebhookService: Raises: ValueError: If the value type doesn't match the expected type """ - type_validators = { - SegmentType.STRING: (lambda v: isinstance(v, str), "string"), - SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"), - SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"), - SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"), - SegmentType.ARRAY_STRING: ( - lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v), - "array of strings", - ), - SegmentType.ARRAY_NUMBER: ( - lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v), - "array of numbers", - ), - SegmentType.ARRAY_BOOLEAN: ( - lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v), - "array of booleans", - ), - SegmentType.ARRAY_OBJECT: ( - lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v), - "array of objects", - ), - } - - validator_info = type_validators.get(SegmentType(param_type)) - if not validator_info: - logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name) + param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name) + if param_type_enum is None: return value - validator, expected_type = validator_info - if not validator(value): + if not param_type_enum.is_valid(value, array_validation=ArrayValidation.ALL): actual_type = type(value).__name__ + expected_type = cls._expected_type_label(param_type_enum) raise ValueError(f"Expected {expected_type}, got {actual_type}") return value @classmethod - def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None: + def _coerce_segment_type(cls, param_type: SegmentType | str, *, param_name: str) -> SegmentType | None: + if isinstance(param_type, SegmentType): + return param_type + try: + return SegmentType(param_type) + except Exception: + logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name) + return None + + @staticmethod + def _expected_type_label(param_type: SegmentType) -> str: + match param_type: + case SegmentType.ARRAY_STRING: + return "array of strings" + case SegmentType.ARRAY_NUMBER: + return "array of numbers" + case SegmentType.ARRAY_BOOLEAN: + return "array of booleans" + case SegmentType.ARRAY_OBJECT: + return "array of objects" + case _: + return param_type.value + + @classmethod + def _validate_required_headers(cls, headers: dict[str, Any], header_configs: Sequence[WebhookParameter]) -> None: """Validate required headers are present. Args: @@ -639,14 +654,14 @@ class WebhookService: headers_lower = {k.lower(): v for k, v in headers.items()} headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()} for header_config in header_configs: - if header_config.get("required", False): - header_name = header_config.get("name", "") + if header_config.required: + header_name = header_config.name sanitized_name = cls._sanitize_key(header_name).lower() if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized: raise ValueError(f"Required header missing: {header_name}") @classmethod - def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]: """Validate HTTP method and content-type. Args: @@ -657,13 +672,13 @@ class WebhookService: dict[str, Any]: Validation result with 'valid' key and optional 'error' key """ # Validate HTTP method - configured_method = node_data.get("method", "get").upper() + configured_method = node_data.method.value.upper() request_method = webhook_data["method"].upper() if configured_method != request_method: return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}") # Validate Content-type - configured_content_type = node_data.get("content_type", "application/json").lower() + configured_content_type = node_data.content_type.value.lower() request_content_type = cls._extract_content_type(webhook_data["headers"]) if configured_content_type != request_content_type: @@ -788,7 +803,7 @@ class WebhookService: raise @classmethod - def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]: + def generate_webhook_response(cls, node_config: NodeConfigDict) -> tuple[dict[str, Any], int]: """Generate HTTP response based on node configuration. Args: @@ -797,11 +812,11 @@ class WebhookService: Returns: tuple[dict[str, Any], int]: Response data and HTTP status code """ - node_data = node_config.get("data", {}) + node_data = WebhookData.model_validate(node_config["data"], from_attributes=True) # Get configured status code and response body - status_code = node_data.get("status_code", 200) - response_body = node_data.get("response_body", "") + status_code = node_data.status_code + response_body = node_data.response_body # Parse response body as JSON if it's valid JSON, otherwise return as text try: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6d462b60b9..ec37fbeb6b 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -16,6 +16,7 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams, WorkflowNodeExecution +from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.errors import WorkflowNodeRunFailedError @@ -693,7 +694,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - node_data = node_config.get("data", {}) + node_data = node_config["data"] if node_type.is_start_node: with Session(bind=db.engine) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) @@ -703,7 +704,7 @@ class WorkflowService: workflow=draft_workflow, ) if node_type is NodeType.START: - start_data = StartNodeData.model_validate(node_data) + start_data = StartNodeData.model_validate(node_data, from_attributes=True) user_inputs = _rebuild_file_for_user_inputs_in_start_node( tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs ) @@ -941,7 +942,7 @@ class WorkflowService: if node_type is not NodeType.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config.get("data", {})) + node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, @@ -1059,7 +1060,7 @@ class WorkflowService: *, workflow: Workflow, account: Account, - node_config: Mapping[str, Any], + node_config: NodeConfigDict, variable_pool: VariablePool, ) -> HumanInputNode: graph_init_params = GraphInitParams( @@ -1079,7 +1080,7 @@ class WorkflowService: start_at=time.perf_counter(), ) node = HumanInputNode( - id=node_config.get("id", str(uuid.uuid4())), + id=node_config["id"], config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -1092,7 +1093,7 @@ class WorkflowService: *, app_model: App, workflow: Workflow, - node_config: Mapping[str, Any], + node_config: NodeConfigDict, manual_inputs: Mapping[str, Any], ) -> VariablePool: with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f691113511..347fa9c9ed 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -189,6 +189,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from dify_graph.enums import NodeType from dify_graph.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, @@ -209,6 +210,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): # Create node data with custom auth and empty api_key node_data = HttpRequestNodeData( + type=NodeType.HTTP_REQUEST, title="http", desc="", url="http://example.com", diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index f91e6efb10..970da98c55 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -173,7 +173,7 @@ class TestWebhookService: assert workflow.app_id == test_data["app"].id assert node_config is not None assert node_config["id"] == "webhook_node" - assert node_config["data"]["title"] == "Test Webhook" + assert node_config["data"].title == "Test Webhook" def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers): """Test webhook trigger not found scenario.""" diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index cf52980e57..d6933e2180 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -25,7 +25,8 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing # load dotenv file with pydantic-settings - config = DifyConfig() + # Disable `.env` loading to ensure test stability across environments + config = DifyConfig(_env_file=None) # constant values assert config.COMMIT_SHA == "" @@ -59,7 +60,8 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("DB_PORT", "5432") monkeypatch.setenv("DB_DATABASE", "dify") - config = DifyConfig() + # Disable `.env` loading to ensure test stability across environments + config = DifyConfig(_env_file=None) # Verify default timeout values assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10 @@ -86,7 +88,8 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("WEB_API_CORS_ALLOW_ORIGINS", "http://127.0.0.1:3000,*") monkeypatch.setenv("CODE_EXECUTION_ENDPOINT", "http://127.0.0.1:8194/") - flask_app.config.from_mapping(DifyConfig().model_dump()) # pyright: ignore + # Disable `.env` loading to ensure test stability across environments + flask_app.config.from_mapping(DifyConfig(_env_file=None).model_dump()) # pyright: ignore config = flask_app.config # configs read from pydantic-settings diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index e019a4b977..4f67d9cb56 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -8,6 +8,8 @@ from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import NodeType, WorkflowNodeExecutionStatus @@ -22,7 +24,7 @@ from dify_graph.graph_events import ( NodeRunSucceededEvent, ) from dify_graph.node_events import NodeRunResult, PauseRequestedEvent -from dify_graph.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig +from dify_graph.nodes.base.entities import OutputVariableEntity from dify_graph.nodes.base.node import Node from dify_graph.nodes.end.entities import EndNodeData from dify_graph.nodes.start.entities import StartNodeData @@ -42,6 +44,7 @@ if "core.ops.ops_trace_manager" not in sys.modules: class _StubToolNodeData(BaseNodeData): + type: NodeType = NodeType.TOOL pause_on: bool = False @@ -88,16 +91,17 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): original_create_node = DifyNodeFactory.create_node - def _patched_create_node(self, node_config: dict[str, object]) -> Node: - node_data = node_config.get("data", {}) - if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value: + def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: + typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + node_data = typed_node_config["data"] + if node_data.type == NodeType.TOOL: return _StubToolNode( - id=str(node_config["id"]), - config=node_config, + id=str(typed_node_config["id"]), + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, ) - return original_create_node(self, node_config) + return original_create_node(self, typed_node_config) mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 2e0715e974..178e26118e 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -7,7 +7,9 @@ import pytest from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable from models.workflow import Workflow @@ -105,3 +107,57 @@ def test_run_uses_single_node_execution_branch( assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER assert entry_kwargs["variable_pool"] is variable_pool assert entry_kwargs["graph_runtime_state"] is graph_runtime_state + + +def test_single_node_run_validates_target_node_config(monkeypatch) -> None: + runner = WorkflowBasedAppRunner( + queue_manager=MagicMock(spec=AppQueueManager), + variable_loader=MagicMock(), + app_id="app", + ) + + workflow = MagicMock(spec=Workflow) + workflow.id = "workflow" + workflow.tenant_id = "tenant" + workflow.graph_dict = { + "nodes": [ + { + "id": "loop-node", + "data": { + "type": "loop", + "title": "Loop", + "loop_count": 1, + "break_conditions": [], + "logical_operator": "and", + }, + } + ], + "edges": [], + } + + _, _, graph_runtime_state = _make_graph_state() + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + with ( + patch("core.app.apps.workflow_app_runner.DifyNodeFactory"), + patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()), + patch("core.app.apps.workflow_app_runner.load_into_variable_pool"), + patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"), + ): + runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="loop-node", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", + ) + + assert seen_configs == [workflow.graph_dict["nodes"][0]] diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index b98d56147e..9e9fc2e9ec 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -7,10 +7,10 @@ from dataclasses import dataclass import pytest from dify_graph.entities import GraphInitParams +from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeType from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -183,3 +183,36 @@ def test_graph_validation_blocks_start_and_trigger_coexistence( Graph.init(graph_config=graph_config, node_factory=node_factory) assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues) + + +def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + { + "id": "start", + "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, + {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}, + { + "id": "note", + "type": "custom-note", + "data": { + "type": "", + "title": "", + "desc": "", + "text": "{}", + "theme": "blue", + }, + }, + ] + graph_config["edges"] = [ + {"source": "start", "target": "answer", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + + assert graph.root_node.id == "start" + assert "answer" in graph.nodes + assert "note" not in graph.nodes diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py index 61e0f12550..2b926d754c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -2,6 +2,7 @@ from __future__ import annotations +from dify_graph.entities.base_node_data import RetryConfig from dify_graph.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine.domain.graph_execution import GraphExecution @@ -12,7 +13,6 @@ from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.entities import RetryConfig from dify_graph.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index b9ae680f52..4e13177d2b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -10,6 +10,7 @@ import time from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st +from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType from dify_graph.enums import ErrorStrategy from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel @@ -18,7 +19,6 @@ from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.nodes.base.entities import DefaultValue, DefaultValueType # Import the test framework from the new module from .test_mock_config import MockConfigBuilder diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 9f33a81985..338db9076e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -5,10 +5,10 @@ This module provides a MockNodeFactory that automatically detects and mocks node requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request). """ -from collections.abc import Mapping from typing import TYPE_CHECKING, Any from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.nodes.base.node import Node @@ -75,39 +75,27 @@ class MockNodeFactory(DifyNodeFactory): NodeType.CODE: MockCodeNode, } - def create_node(self, node_config: Mapping[str, Any]) -> Node: + def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ Create a node instance, using mock implementations for third-party service nodes. :param node_config: Node configuration dictionary :return: Node instance (real or mocked) """ - # Get node type from config - node_data = node_config.get("data", {}) - node_type_str = node_data.get("type") - - if not node_type_str: - # Fall back to parent implementation for nodes without type - return super().create_node(node_config) - - try: - node_type = NodeType(node_type_str) - except ValueError: - # Unknown node type, use parent implementation - return super().create_node(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + node_data = typed_node_config["data"] + node_type = node_data.type # Check if this node type should be mocked if node_type in self._mock_node_types: - node_id = node_config.get("id") - if not node_id: - raise ValueError("Node config missing id") + node_id = typed_node_config["id"] # Create mock node instance mock_class = self._mock_node_types[node_type] if node_type == NodeType.CODE: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -117,7 +105,7 @@ class MockNodeFactory(DifyNodeFactory): elif node_type == NodeType.HTTP_REQUEST: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -129,7 +117,7 @@ class MockNodeFactory(DifyNodeFactory): elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -139,7 +127,7 @@ class MockNodeFactory(DifyNodeFactory): else: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -148,7 +136,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # For non-mocked node types, use parent implementation - return super().create_node(node_config) + return super().create_node(typed_node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index bf814d0c97..3fb775f934 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,7 +1,7 @@ import pytest +from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import NodeType -from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node # Ensures that all node classes are imported. @@ -126,3 +126,20 @@ def test_init_subclass_sets_node_data_type_from_generic(): return "1" assert _AutoNode._node_data_type is _TestNodeData + + +def test_validate_node_data_uses_declared_node_data_type(): + """Public validation should hydrate the subclass-declared node data model.""" + + class _AutoNode(Node[_TestNodeData]): + node_type = NodeType.CODE + + @staticmethod + def version() -> str: + return "1" + + base_node_data = BaseNodeData.model_validate({"type": NodeType.CODE, "title": "Test"}) + + validated = _AutoNode.validate_node_data(base_node_data) + + assert isinstance(validated, _TestNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index f8d799e446..86d326aead 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,8 +1,8 @@ import types from collections.abc import Mapping +from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import NodeType -from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 95cb653635..784e08edd2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -272,7 +272,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_1", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert result == {} @@ -292,7 +292,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_1", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert "node_1.input_text" in result @@ -315,7 +315,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="code_node", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert len(result) == 3 @@ -338,7 +338,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_x", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"] @@ -437,7 +437,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node._node_data = node._hydrate_node_data(data) + node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True) assert node._node_data.title == "Test Node" assert node._node_data.code_language == CodeLanguage.PYTHON3 @@ -453,7 +453,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node._node_data = node._hydrate_node_data(data) + node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True) assert node._node_data.code_language == CodeLanguage.JAVASCRIPT diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py index b95a7ad8ae..490df52533 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py @@ -1,3 +1,4 @@ +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from dify_graph.nodes.iteration.exc import ( @@ -388,3 +389,50 @@ class TestIterationNodeErrorStrategies: result = node._get_default_value_dict() assert isinstance(result, dict) + + +def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + child_node_config = { + "id": "answer-node", + "data": { + "type": "answer", + "title": "Answer", + "answer": "", + "iteration_id": "iteration-node", + }, + } + + IterationNode._extract_variable_selector_to_variable_mapping( + graph_config={ + "nodes": [ + { + "id": "iteration-node", + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": ["iteration", "result"], + }, + }, + child_node_config, + ], + "edges": [], + }, + node_id="iteration-node", + node_data=IterationNodeData( + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration", "result"], + ), + ) + + assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index e929d652fd..b7a7a9c938 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -410,14 +410,14 @@ class TestKnowledgeRetrievalNode: """Test _extract_variable_selector_to_variable_mapping class method.""" # Arrange node_id = "knowledge_node_1" - node_data = { - "type": "knowledge-retrieval", - "title": "Knowledge Retrieval", - "dataset_ids": [str(uuid.uuid4())], - "retrieval_mode": "multiple", - "query_variable_selector": ["start", "query"], - "query_attachment_selector": ["start", "attachments"], - } + node_data = KnowledgeRetrievalNodeData( + type="knowledge-retrieval", + title="Knowledge Retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="multiple", + query_variable_selector=["start", "query"], + query_attachment_selector=["start", "attachments"], + ) graph_config = {} # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 44abf430c0..0d81e7762b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -4,8 +4,9 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.entities import GraphInitParams +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import NodeType -from dify_graph.nodes.base.entities import BaseNodeData from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool from dify_graph.system_variable import SystemVariable @@ -40,13 +41,26 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state +def _build_node_config() -> NodeConfigDict: + return NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": NodeType.ANSWER.value, + "title": "Sample", + "foo": "bar", + }, + } + ) + + def test_node_hydrates_data_during_initialization(): graph_config: dict[str, object] = {} init_params, runtime_state = _build_context(graph_config) node = _SampleNode( id="node-1", - config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}, + config=_build_node_config(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -72,7 +86,7 @@ def test_node_accepts_invoke_from_enum(): node = _SampleNode( id="node-1", - config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}, + config=_build_node_config(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -99,3 +113,17 @@ def test_missing_generic_argument_raises_type_error(): def _run(self): raise NotImplementedError + + +def test_base_node_data_keeps_dict_style_access_compatibility(): + node_data = _SampleNodeData.model_validate( + { + "type": NodeType.ANSWER.value, + "title": "Sample", + "foo": "bar", + } + ) + + assert node_data["foo"] == "bar" + assert node_data.get("foo") == "bar" + assert node_data.get("missing", "fallback") == "fallback" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py new file mode 100644 index 0000000000..6372583839 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py @@ -0,0 +1,52 @@ +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.nodes.loop.entities import LoopNodeData +from dify_graph.nodes.loop.loop_node import LoopNode + + +def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + child_node_config = { + "id": "answer-node", + "data": { + "type": "answer", + "title": "Answer", + "answer": "", + "loop_id": "loop-node", + }, + } + + LoopNode._extract_variable_selector_to_variable_mapping( + graph_config={ + "nodes": [ + { + "id": "loop-node", + "data": { + "type": "loop", + "title": "Loop", + "loop_count": 1, + "break_conditions": [], + "logical_operator": "and", + }, + }, + child_node_config, + ], + "edges": [], + }, + node_id="loop-node", + node_data=LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + ), + ) + + assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 410c4993e4..61b18566b0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -210,9 +210,6 @@ def test_webhook_data_model_dump_with_alias(): def test_webhook_data_validation_errors(): """Test WebhookData validation errors.""" - # Title is required (inherited from BaseNodeData) - with pytest.raises(ValidationError): - WebhookData() # Invalid method with pytest.raises(ValidationError): @@ -254,6 +251,36 @@ def test_webhook_data_sequence_fields(): assert len(data.headers) == 1 # Should still be 1 +def test_webhook_data_rejects_non_string_header_types(): + """Headers should stay string-only because runtime does not coerce header values.""" + for param_type in ["number", "boolean", "object", "array[string]", "file"]: + with pytest.raises(ValidationError): + WebhookData( + title="Test", + headers=[WebhookParameter(name="X-Test", type=param_type)], + ) + + +def test_webhook_data_limits_query_param_types_to_scalar_values(): + """Query params only support scalar conversions in the current runtime.""" + data = WebhookData( + title="Test", + params=[ + WebhookParameter(name="count", type="number"), + WebhookParameter(name="enabled", type="boolean"), + ], + ) + assert data.params[0].type == "number" + assert data.params[1].type == "boolean" + + for param_type in ["object", "array[string]", "array[number]", "array[boolean]", "array[object]", "file"]: + with pytest.raises(ValidationError): + WebhookData( + title="Test", + params=[WebhookParameter(name="test", type=param_type)], + ) + + def test_webhook_data_sync_mode(): """Test WebhookData SyncMode nested enum.""" # Test that SyncMode enum exists and has expected value @@ -297,7 +324,7 @@ def test_webhook_body_parameter_edge_cases(): def test_webhook_data_inheritance(): """Test WebhookData inherits from BaseNodeData correctly.""" - from dify_graph.nodes.base import BaseNodeData + from dify_graph.entities.base_node_data import BaseNodeData # Test that WebhookData is a subclass of BaseNodeData assert issubclass(WebhookData, BaseNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index f2273e441e..a821e361c5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,6 +1,6 @@ import pytest -from dify_graph.nodes.base.exc import BaseNodeError +from dify_graph.entities.exc import BaseNodeError from dify_graph.nodes.trigger_webhook.exc import ( WebhookConfigError, WebhookNodeError, diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py new file mode 100644 index 0000000000..22be656d4b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import Any + +from core.model_manager import ModelInstance +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.nodes.llm.entities import LLMNodeData +from dify_graph.nodes.llm.node import LLMNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params + + +def _build_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: + graph_init_params = build_test_graph_init_params( + workflow_id="workflow", + graph_config=graph_config, + tenant_id="tenant", + app_id="app", + user_id="user", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool( + system_variables=SystemVariable.default(), + user_inputs={}, + environment_variables=[], + ), + start_at=0.0, + ) + return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + + +def test_create_node_uses_declared_node_data_type_for_llm_validation(monkeypatch): + class _FactoryLLMNodeData(LLMNodeData): + pass + + llm_node_config = { + "id": "llm-node", + "data": { + "type": "llm", + "title": "LLM", + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "mode": "chat", + "completion_params": {}, + }, + "prompt_template": [], + "context": { + "enabled": False, + }, + }, + } + graph_config = {"nodes": [llm_node_config], "edges": []} + factory = _build_factory(graph_config) + captured: dict[str, object] = {} + + monkeypatch.setattr(LLMNode, "_node_data_type", _FactoryLLMNodeData) + + def _capture_model_instance(self: DifyNodeFactory, node_data: object) -> ModelInstance: + captured["node_data"] = node_data + return object() # type: ignore[return-value] + + def _capture_memory( + self: DifyNodeFactory, + *, + node_data: object, + model_instance: ModelInstance, + ) -> None: + captured["memory_node_data"] = node_data + + monkeypatch.setattr(DifyNodeFactory, "_build_model_instance_for_llm_node", _capture_model_instance) + monkeypatch.setattr(DifyNodeFactory, "_build_memory_for_llm_node", _capture_memory) + + node = factory.create_node(llm_node_config) + + assert isinstance(captured["node_data"], _FactoryLLMNodeData) + assert isinstance(captured["memory_node_data"], _FactoryLLMNodeData) + assert isinstance(node.node_data, _FactoryLLMNodeData) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 0aa6ec3f45..93ba7f3333 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -9,6 +9,7 @@ from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.file.enums import FileType from dify_graph.file.models import File, FileTransferMethod from dify_graph.nodes.code.code_node import CodeNode @@ -124,7 +125,7 @@ class TestWorkflowEntry: def get_node_config_by_id(self, target_id: str): assert target_id == node_id - return node_config + return NodeConfigDictAdapter.validate_python(node_config) workflow = StubWorkflow() variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={}) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index 5d6fa4c137..fcdd1c2368 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock import pytest from sqlalchemy.orm import sessionmaker +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, @@ -22,7 +23,7 @@ def _make_service() -> WorkflowService: return WorkflowService(session_maker=sessionmaker()) -def _build_node_config(delivery_methods): +def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict: node_data = HumanInputNodeData( title="Human Input", delivery_methods=delivery_methods, @@ -31,7 +32,7 @@ def _build_node_config(delivery_methods): user_actions=[], ).model_dump(mode="json") node_data["type"] = NodeType.HUMAN_INPUT.value - return {"id": "node-1", "data": node_data} + return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data}) def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod: diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 83c1f8d9da..e428c603a4 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock import pytest +from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import NodeType from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction from dify_graph.nodes.human_input.enums import FormInputType @@ -187,7 +188,10 @@ class TestWorkflowService: service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + node_config = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + ) + workflow.get_node_config_by_id.return_value = node_config workflow.get_enclosing_node_type_and_id.return_value = None service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -232,7 +236,7 @@ class TestWorkflowService: service._build_human_input_variable_pool.assert_called_once_with( app_model=app_model, workflow=workflow, - node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}, + node_config=node_config, manual_inputs={"#node-0.result#": "LLM output"}, ) @@ -267,7 +271,9 @@ class TestWorkflowService: service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + ) service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")