refactor(api): type single node workflow helpers

This commit is contained in:
Yanli 盐粒 2026-03-17 20:16:14 +08:00
parent f874ca183e
commit 1dce81c604
4 changed files with 167 additions and 75 deletions

View File

@ -1,13 +1,17 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Protocol, TypeAlias
from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.app_invoke_entities import (
InvokeFrom,
UserFrom,
build_dify_run_context,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
GraphConfigObject: TypeAlias = dict[str, object]
GraphConfigMapping: TypeAlias = Mapping[str, object]
class SingleNodeRunEntity(Protocol):
    """Structural type for a single-node run request (iteration or loop).

    Matches the ``SingleIterationRunEntity`` / ``SingleLoopRunEntity`` pydantic
    entities by shape, so helpers can accept either without importing both.
    """

    # Id of the node to execute in isolation.
    node_id: str
    # User-supplied inputs for the run.
    inputs: Mapping[str, object]
class WorkflowBasedAppRunner:
def __init__(
@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_config: GraphConfigMapping,
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
@ -154,8 +166,8 @@ class WorkflowBasedAppRunner:
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
single_iteration_run: SingleNodeRunEntity | None = None,
single_loop_run: SingleNodeRunEntity | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
@ -208,11 +220,88 @@ class WorkflowBasedAppRunner:
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
@staticmethod
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
    """Extract and validate the ``nodes`` and ``edges`` lists of a graph config.

    :param graph_config: raw workflow graph definition
    :return: ``(nodes, edges)``, each a new list of validated mappings
    :raises ValueError: when either value is not a list, or any item is not a mapping
    """
    # Phase 1: both values must be lists (checked before any item validation,
    # so the "must be a list" error takes precedence for either key).
    raw_items: dict[str, list[object]] = {}
    for key in ("nodes", "edges"):
        value = graph_config.get(key)
        if not isinstance(value, list):
            raise ValueError(f"{key} in workflow graph must be a list")
        raw_items[key] = value
    # Phase 2: every item must be a mapping; return shallow copies of the lists.
    validated: dict[str, list[GraphConfigMapping]] = {}
    for key, items in raw_items.items():
        for item in items:
            if not isinstance(item, Mapping):
                raise ValueError(f"{key} in workflow graph must be mappings")
        validated[key] = list(items)
    return validated["nodes"], validated["edges"]
@staticmethod
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
    """Return the ``start_node_id`` declared in a node's ``data`` mapping.

    Returns ``None`` when the node config is missing, ``data`` is not a
    mapping, or the value is not a string.
    """
    if node_config is None:
        return None
    data = node_config.get("data")
    if not isinstance(data, Mapping):
        return None
    candidate = data.get("start_node_id")
    if isinstance(candidate, str):
        return candidate
    return None
@classmethod
def _build_single_node_graph_config(
    cls,
    *,
    graph_config: GraphConfigMapping,
    node_id: str,
    node_type_filter_key: str,
) -> tuple[GraphConfigObject, NodeConfigDict]:
    """Build a reduced graph config around one container node.

    Keeps the target node, its child nodes (those whose ``data`` carries
    *node_type_filter_key* equal to ``node_id``), and the container's declared
    start node; edges are restricted to that subgraph.

    :param graph_config: full workflow graph definition
    :param node_id: id of the node to isolate (iteration/loop container)
    :param node_type_filter_key: ``"iteration_id"`` or ``"loop_id"``
    :return: (reduced graph config, validated target node config)
    :raises ValueError: when the graph is malformed or ``node_id`` is absent
    """
    node_configs, edge_configs = cls._get_graph_items(graph_config)
    # Locate the container node first so we can pull its declared start node id.
    main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
    start_node_id = cls._extract_start_node_id(main_node_config)
    # Copy each kept node with dict(...) so callers never mutate the
    # workflow's stored graph definition.
    filtered_node_configs = [
        dict(node)
        for node in node_configs
        if node.get("id") == node_id
        or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
        or (start_node_id and node.get("id") == start_node_id)
    ]
    if not filtered_node_configs:
        raise ValueError(f"node id {node_id} not found in workflow graph")
    # Only string ids participate in edge filtering below.
    filtered_node_ids = {
        str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
    }
    # Drop edges referencing nodes outside the subgraph; an edge with a missing
    # source/target endpoint is kept.
    filtered_edge_configs = [
        dict(edge)
        for edge in edge_configs
        if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
        and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
    ]
    # The filtered set can be non-empty (children matched) while the target
    # itself is absent, so look it up explicitly before schema validation.
    target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
    if target_node_config is None:
        raise ValueError(f"node id {node_id} not found in workflow graph")
    return (
        {
            "nodes": filtered_node_configs,
            "edges": filtered_edge_configs,
        },
        NodeConfigDictAdapter.validate_python(target_node_config),
    )
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
user_inputs: Mapping[str, object],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
@ -236,41 +325,14 @@ class WorkflowBasedAppRunner:
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
graph_config, target_node_config = self._build_single_node_graph_config(
graph_config=graph_config,
node_id=node_id,
node_type_filter_key=node_type_filter_key,
)
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
@ -299,18 +361,6 @@ class WorkflowBasedAppRunner:
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)

View File

@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None
@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None
@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None
@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None

View File

@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, TypeAlias, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@ -32,6 +32,13 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
SpecialValueScalar: TypeAlias = str | int | float | bool | None
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
SerializedSpecialValue: TypeAlias = (
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
)
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
class _WorkflowChildEngineBuilder:
@staticmethod
@ -276,10 +283,10 @@ class WorkflowEntry:
@staticmethod
def _create_single_node_graph(
node_id: str,
node_data: dict[str, Any],
node_data: Mapping[str, object],
node_width: int = 114,
node_height: int = 514,
) -> dict[str, Any]:
) -> SingleNodeGraphConfig:
"""
Create a minimal graph structure for testing a single node in isolation.
@ -289,14 +296,14 @@ class WorkflowEntry:
:param node_height: height for UI layout (default: 100)
:return: graph dictionary with start node and target node
"""
node_config = {
node_config: dict[str, object] = {
"id": node_id,
"width": node_width,
"height": node_height,
"type": "custom",
"data": node_data,
"data": dict(node_data),
}
start_node_config = {
start_node_config: dict[str, object] = {
"id": "start",
"width": node_width,
"height": node_height,
@ -321,7 +328,12 @@ class WorkflowEntry:
@classmethod
def run_free_node(
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
cls,
node_data: Mapping[str, object],
node_id: str,
tenant_id: str,
user_id: str,
user_inputs: Mapping[str, object],
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Run free node
@ -339,6 +351,8 @@ class WorkflowEntry:
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = node_data.get("type", "")
if not isinstance(node_type, str):
raise ValueError("Node type must be a string")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
@ -369,7 +383,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@ -405,30 +419,34 @@ class WorkflowEntry:
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
# NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain
# data integrity and type information.
result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
if result is None:
return None
if isinstance(result, dict):
return result
raise TypeError("handle_special_values expects a mapping input")
@staticmethod
def _handle_special_values(value: Any):
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
if value is None:
return value
if isinstance(value, dict):
res = {}
if isinstance(value, Mapping):
res: dict[str, SerializedSpecialValue] = {}
for k, v in value.items():
res[k] = WorkflowEntry._handle_special_values(v)
return res
if isinstance(value, list):
res_list = []
res_list: list[SerializedSpecialValue] = []
for item in value:
res_list.append(WorkflowEntry._handle_special_values(item))
return res_list
if isinstance(value, File):
return value.to_dict()
return dict(value.to_dict())
return value
@classmethod

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any
from copy import deepcopy
from unittest.mock import MagicMock, patch
import pytest
@ -33,8 +33,8 @@ def _make_graph_state():
],
)
def test_run_uses_single_node_execution_branch(
single_iteration_run: Any,
single_loop_run: Any,
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
) -> None:
app_config = MagicMock()
app_config.app_id = "app"
@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
"break_conditions": [],
"logical_operator": "and",
},
},
{
"id": "other-node",
"data": {
"type": "answer",
"title": "Answer",
},
},
],
"edges": [
{
"source": "other-node",
"target": "loop-node",
}
],
"edges": [],
}
original_graph_dict = deepcopy(workflow.graph_dict)
_, _, graph_runtime_state = _make_graph_state()
seen_configs: list[object] = []
@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
seen_configs.append(value)
return original_validate_python(value)
class FakeNodeClass:
    # Stand-in for a resolved workflow node class: the variable-mapping hook
    # returns an empty mapping so nothing is loaded into the variable pool.
    @staticmethod
    def extract_variable_selector_to_variable_mapping(**_kwargs):
        return {}
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
with (
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
):
runner._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
)
assert seen_configs == [workflow.graph_dict["nodes"][0]]
assert workflow.graph_dict == original_graph_dict
graph_config = graph_init.call_args.kwargs["graph_config"]
assert graph_config is not workflow.graph_dict
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
assert graph_config["edges"] == []