diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 25d3c8bd2a..d72a7a381a 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,13 +1,17 @@ import logging import time from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Protocol, TypeAlias from pydantic import ValidationError from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.agent_strategy import AgentStrategyInfo -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.entities.app_invoke_entities import ( + InvokeFrom, + UserFrom, + build_dify_run_context, +) from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, @@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.graph import Graph from dify_graph.graph_engine.layers.base import GraphEngineLayer @@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task logger = logging.getLogger(__name__) +GraphConfigObject: TypeAlias = dict[str, object] +GraphConfigMapping: TypeAlias = Mapping[str, object] + + +class SingleNodeRunEntity(Protocol): + node_id: str + inputs: Mapping[str, object] + class WorkflowBasedAppRunner: def __init__( @@ -98,7 +110,7 @@ class WorkflowBasedAppRunner: def _init_graph( self, - graph_config: Mapping[str, Any], + graph_config: GraphConfigMapping, graph_runtime_state: GraphRuntimeState, user_from: UserFrom, invoke_from: InvokeFrom, @@ -154,8 +166,8 @@ class WorkflowBasedAppRunner: def _prepare_single_node_execution( self, workflow: Workflow, - single_iteration_run: Any | None = None, - single_loop_run: Any | None = None, + single_iteration_run: SingleNodeRunEntity | None = None, + single_loop_run: SingleNodeRunEntity | None = None, ) -> tuple[Graph, VariablePool, GraphRuntimeState]: """ Prepare graph, variable pool, and runtime state for single node execution @@ -208,11 +220,88 @@ class WorkflowBasedAppRunner: # This ensures all nodes in the graph reference the same GraphRuntimeState instance return graph, variable_pool, graph_runtime_state + @staticmethod + def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]: + nodes = graph_config.get("nodes") + edges = graph_config.get("edges") + if not isinstance(nodes, list): + raise ValueError("nodes in workflow graph must be a list") + if not isinstance(edges, list): + raise ValueError("edges in workflow graph must be a list") + + validated_nodes: list[GraphConfigMapping] = [] + for node in nodes: + if not isinstance(node, Mapping): + raise ValueError("nodes in workflow graph must be mappings") + validated_nodes.append(node) + + validated_edges: list[GraphConfigMapping] = [] + for edge in edges: + if not isinstance(edge, Mapping): + raise ValueError("edges in workflow graph must be mappings") + validated_edges.append(edge) + + return validated_nodes, validated_edges + + @staticmethod + def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None: + if node_config is None: + return None + node_data = node_config.get("data") + if not isinstance(node_data, Mapping): + return None + start_node_id = node_data.get("start_node_id") + return start_node_id if isinstance(start_node_id, str) else None + + @classmethod + def _build_single_node_graph_config( + cls, + *, + graph_config: GraphConfigMapping, + node_id: str, + node_type_filter_key: str, + ) -> tuple[GraphConfigObject, NodeConfigDict]: + node_configs, edge_configs = cls._get_graph_items(graph_config) + main_node_config = next((node for node in node_configs if node.get("id") == node_id), None) + start_node_id = cls._extract_start_node_id(main_node_config) + + filtered_node_configs = [ + dict(node) + for node in node_configs + if node.get("id") == node_id + or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id) + or (start_node_id and node.get("id") == start_node_id) + ] + if not filtered_node_configs: + raise ValueError(f"node id {node_id} not found in workflow graph") + + filtered_node_ids = { + str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str) + } + filtered_edge_configs = [ + dict(edge) + for edge in edge_configs + if (edge.get("source") is None or edge.get("source") in filtered_node_ids) + and (edge.get("target") is None or edge.get("target") in filtered_node_ids) + ] + + target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None) + if target_node_config is None: + raise ValueError(f"node id {node_id} not found in workflow graph") + + return ( + { + "nodes": filtered_node_configs, + "edges": filtered_edge_configs, + }, + NodeConfigDictAdapter.validate_python(target_node_config), + ) + def _get_graph_and_variable_pool_for_single_node_run( self, workflow: Workflow, node_id: str, - user_inputs: dict[str, Any], + user_inputs: Mapping[str, object], graph_runtime_state: GraphRuntimeState, node_type_filter_key: str, # 'iteration_id' or 'loop_id' node_type_label: str = "node", # 'iteration' or 'loop' for error messages @@ -236,41 +325,14 @@ class WorkflowBasedAppRunner: if not graph_config: raise ValueError("workflow graph not found") - graph_config = cast(dict[str, Any], graph_config) - if "nodes" not in graph_config or "edges" not in graph_config: raise ValueError("nodes or edges not found in workflow graph") - if not isinstance(graph_config.get("nodes"), list): - raise ValueError("nodes in workflow graph must be a list") - - if not isinstance(graph_config.get("edges"), list): - raise ValueError("edges in workflow graph must be a list") - - # filter nodes only in the specified node type (iteration or loop) - main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None) - start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None - node_configs = [ - node - for node in graph_config.get("nodes", []) - if node.get("id") == node_id - or node.get("data", {}).get(node_type_filter_key, "") == node_id - or (start_node_id and node.get("id") == start_node_id) - ] - - graph_config["nodes"] = node_configs - - node_ids = [node.get("id") for node in node_configs] - - # filter edges only in the specified node type - edge_configs = [ - edge - for edge in graph_config.get("edges", []) - if (edge.get("source") is None or edge.get("source") in node_ids) - and (edge.get("target") is None or edge.get("target") in node_ids) - ] - - graph_config["edges"] = edge_configs + graph_config, target_node_config = self._build_single_node_graph_config( + graph_config=graph_config, + node_id=node_id, + node_type_filter_key=node_type_filter_key, + ) # Create required parameters for Graph.init graph_init_params = GraphInitParams( @@ -299,18 +361,6 @@ class WorkflowBasedAppRunner: if not graph: raise ValueError("graph not found in workflow") - # fetch node config from node id - target_node_config = None - for node in node_configs: - if node.get("id") == node_id: - target_node_config = node - break - - if not target_node_config: - raise ValueError(f"{node_type_label} node id not found in workflow graph") - - target_node_config = NodeConfigDictAdapter.validate_python(target_node_config) - # Get node class node_type = target_node_config["data"].type node_version = str(target_node_config["data"].version) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index ecbb1cf2f3..ad114d63aa 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ node_id: str - inputs: Mapping + inputs: Mapping[str, object] single_iteration_run: SingleIterationRunEntity | None = None @@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ node_id: str - inputs: Mapping + inputs: Mapping[str, object] single_loop_run: SingleLoopRunEntity | None = None @@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ node_id: str - inputs: dict + inputs: Mapping[str, object] single_iteration_run: SingleIterationRunEntity | None = None @@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ node_id: str - inputs: dict + inputs: Mapping[str, object] single_loop_run: SingleLoopRunEntity | None = None diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2e51a06bab..70359bf21c 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,7 +1,7 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any, TypeAlias, cast from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError @@ -32,6 +32,13 @@ from models.workflow import Workflow logger = logging.getLogger(__name__) +SpecialValueScalar: TypeAlias = str | int | float | bool | None +SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"] +SerializedSpecialValue: TypeAlias = ( + SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"] +) +SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]] + class _WorkflowChildEngineBuilder: @staticmethod @@ -276,10 +283,10 @@ class WorkflowEntry: @staticmethod def _create_single_node_graph( node_id: str, - node_data: dict[str, Any], + node_data: Mapping[str, object], node_width: int = 114, node_height: int = 514, - ) -> dict[str, Any]: + ) -> SingleNodeGraphConfig: """ Create a minimal graph structure for testing a single node in isolation. @@ -289,14 +296,14 @@ class WorkflowEntry: :param node_height: height for UI layout (default: 100) :return: graph dictionary with start node and target node """ - node_config = { + node_config: dict[str, object] = { "id": node_id, "width": node_width, "height": node_height, "type": "custom", - "data": node_data, + "data": dict(node_data), } - start_node_config = { + start_node_config: dict[str, object] = { "id": "start", "width": node_width, "height": node_height, @@ -321,7 +328,12 @@ class WorkflowEntry: @classmethod def run_free_node( - cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] + cls, + node_data: Mapping[str, object], + node_id: str, + tenant_id: str, + user_id: str, + user_inputs: Mapping[str, object], ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: """ Run free node @@ -339,6 +351,8 @@ class WorkflowEntry: graph_dict = cls._create_single_node_graph(node_id, node_data) node_type = node_data.get("type", "") + if not isinstance(node_type, str): + raise ValueError("Node type must be a string") if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}: raise ValueError(f"Node type {node_type} not supported") @@ -369,7 +383,7 @@ class WorkflowEntry: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state - node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) + node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)}) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -405,30 +419,34 @@ class WorkflowEntry: raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) @staticmethod - def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: + def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None: # NOTE(QuantumGhost): Avoid using this function in new code. # Keep values structured as long as possible and only convert to dict # immediately before serialization (e.g., JSON serialization) to maintain # data integrity and type information. result = WorkflowEntry._handle_special_values(value) - return result if isinstance(result, Mapping) or result is None else dict(result) + if result is None: + return None + if isinstance(result, dict): + return result + raise TypeError("handle_special_values expects a mapping input") @staticmethod - def _handle_special_values(value: Any): + def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue: if value is None: return value - if isinstance(value, dict): - res = {} + if isinstance(value, Mapping): + res: dict[str, SerializedSpecialValue] = {} for k, v in value.items(): res[k] = WorkflowEntry._handle_special_values(v) return res if isinstance(value, list): - res_list = [] + res_list: list[SerializedSpecialValue] = [] for item in value: res_list.append(WorkflowEntry._handle_special_values(item)) return res_list if isinstance(value, File): - return value.to_dict() + return dict(value.to_dict()) return value @classmethod diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 178e26118e..38416d57a1 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from copy import deepcopy from unittest.mock import MagicMock, patch import pytest @@ -33,8 +33,8 @@ def _make_graph_state(): ], ) def test_run_uses_single_node_execution_branch( - single_iteration_run: Any, - single_loop_run: Any, + single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None, + single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None, ) -> None: app_config = MagicMock() app_config.app_id = "app" @@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None: "break_conditions": [], "logical_operator": "and", }, + }, + { + "id": "other-node", + "data": { + "type": "answer", + "title": "Answer", + }, + }, + ], + "edges": [ + { + "source": "other-node", + "target": "loop-node", } ], - "edges": [], } + original_graph_dict = deepcopy(workflow.graph_dict) _, _, graph_runtime_state = _make_graph_state() seen_configs: list[object] = [] @@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None: seen_configs.append(value) return original_validate_python(value) + class FakeNodeClass: + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {} + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) with ( patch("core.app.apps.workflow_app_runner.DifyNodeFactory"), - patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()), + patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init, patch("core.app.apps.workflow_app_runner.load_into_variable_pool"), patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"), + patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass), ): runner._get_graph_and_variable_pool_for_single_node_run( workflow=workflow, @@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None: ) assert seen_configs == [workflow.graph_dict["nodes"][0]] + assert workflow.graph_dict == original_graph_dict + graph_config = graph_init.call_args.kwargs["graph_config"] + assert graph_config is not workflow.graph_dict + assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]] + assert graph_config["edges"] == []