mirror of
https://github.com/langgenius/dify.git
synced 2026-05-11 06:37:13 +08:00
refactor(api): type single node workflow helpers
This commit is contained in:
parent
f874ca183e
commit
1dce81c604
@ -1,13 +1,17 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Protocol, TypeAlias
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
UserFrom,
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GraphConfigObject: TypeAlias = dict[str, object]
|
||||
GraphConfigMapping: TypeAlias = Mapping[str, object]
|
||||
|
||||
|
||||
class SingleNodeRunEntity(Protocol):
|
||||
node_id: str
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner:
|
||||
def __init__(
|
||||
@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: GraphConfigMapping,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
@ -154,8 +166,8 @@ class WorkflowBasedAppRunner:
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
single_iteration_run: SingleNodeRunEntity | None = None,
|
||||
single_loop_run: SingleNodeRunEntity | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
@ -208,11 +220,88 @@ class WorkflowBasedAppRunner:
|
||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||
return graph, variable_pool, graph_runtime_state
|
||||
|
||||
@staticmethod
|
||||
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
|
||||
nodes = graph_config.get("nodes")
|
||||
edges = graph_config.get("edges")
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
if not isinstance(edges, list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
validated_nodes: list[GraphConfigMapping] = []
|
||||
for node in nodes:
|
||||
if not isinstance(node, Mapping):
|
||||
raise ValueError("nodes in workflow graph must be mappings")
|
||||
validated_nodes.append(node)
|
||||
|
||||
validated_edges: list[GraphConfigMapping] = []
|
||||
for edge in edges:
|
||||
if not isinstance(edge, Mapping):
|
||||
raise ValueError("edges in workflow graph must be mappings")
|
||||
validated_edges.append(edge)
|
||||
|
||||
return validated_nodes, validated_edges
|
||||
|
||||
@staticmethod
|
||||
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
|
||||
if node_config is None:
|
||||
return None
|
||||
node_data = node_config.get("data")
|
||||
if not isinstance(node_data, Mapping):
|
||||
return None
|
||||
start_node_id = node_data.get("start_node_id")
|
||||
return start_node_id if isinstance(start_node_id, str) else None
|
||||
|
||||
@classmethod
|
||||
def _build_single_node_graph_config(
|
||||
cls,
|
||||
*,
|
||||
graph_config: GraphConfigMapping,
|
||||
node_id: str,
|
||||
node_type_filter_key: str,
|
||||
) -> tuple[GraphConfigObject, NodeConfigDict]:
|
||||
node_configs, edge_configs = cls._get_graph_items(graph_config)
|
||||
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
|
||||
start_node_id = cls._extract_start_node_id(main_node_config)
|
||||
|
||||
filtered_node_configs = [
|
||||
dict(node)
|
||||
for node in node_configs
|
||||
if node.get("id") == node_id
|
||||
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
if not filtered_node_configs:
|
||||
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||
|
||||
filtered_node_ids = {
|
||||
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
|
||||
}
|
||||
filtered_edge_configs = [
|
||||
dict(edge)
|
||||
for edge in edge_configs
|
||||
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
|
||||
]
|
||||
|
||||
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
|
||||
if target_node_config is None:
|
||||
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||
|
||||
return (
|
||||
{
|
||||
"nodes": filtered_node_configs,
|
||||
"edges": filtered_edge_configs,
|
||||
},
|
||||
NodeConfigDictAdapter.validate_python(target_node_config),
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_for_single_node_run(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
user_inputs: Mapping[str, object],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
@ -236,41 +325,14 @@ class WorkflowBasedAppRunner:
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in the specified node type
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
graph_config, target_node_config = self._build_single_node_graph_config(
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_type_filter_key=node_type_filter_key,
|
||||
)
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
@ -299,18 +361,6 @@ class WorkflowBasedAppRunner:
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
target_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
|
||||
# Get node class
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
|
||||
@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_iteration_run: SingleIterationRunEntity | None = None
|
||||
|
||||
@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_iteration_run: SingleIterationRunEntity | None = None
|
||||
|
||||
@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
@ -32,6 +32,13 @@ from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SpecialValueScalar: TypeAlias = str | int | float | bool | None
|
||||
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
|
||||
SerializedSpecialValue: TypeAlias = (
|
||||
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
|
||||
)
|
||||
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
@staticmethod
|
||||
@ -276,10 +283,10 @@ class WorkflowEntry:
|
||||
@staticmethod
|
||||
def _create_single_node_graph(
|
||||
node_id: str,
|
||||
node_data: dict[str, Any],
|
||||
node_data: Mapping[str, object],
|
||||
node_width: int = 114,
|
||||
node_height: int = 514,
|
||||
) -> dict[str, Any]:
|
||||
) -> SingleNodeGraphConfig:
|
||||
"""
|
||||
Create a minimal graph structure for testing a single node in isolation.
|
||||
|
||||
@ -289,14 +296,14 @@ class WorkflowEntry:
|
||||
:param node_height: height for UI layout (default: 100)
|
||||
:return: graph dictionary with start node and target node
|
||||
"""
|
||||
node_config = {
|
||||
node_config: dict[str, object] = {
|
||||
"id": node_id,
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": node_data,
|
||||
"data": dict(node_data),
|
||||
}
|
||||
start_node_config = {
|
||||
start_node_config: dict[str, object] = {
|
||||
"id": "start",
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
@ -321,7 +328,12 @@ class WorkflowEntry:
|
||||
|
||||
@classmethod
|
||||
def run_free_node(
|
||||
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
||||
cls,
|
||||
node_data: Mapping[str, object],
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_inputs: Mapping[str, object],
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Run free node
|
||||
@ -339,6 +351,8 @@ class WorkflowEntry:
|
||||
graph_dict = cls._create_single_node_graph(node_id, node_data)
|
||||
|
||||
node_type = node_data.get("type", "")
|
||||
if not isinstance(node_type, str):
|
||||
raise ValueError("Node type must be a string")
|
||||
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
@ -369,7 +383,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@ -405,30 +419,34 @@ class WorkflowEntry:
|
||||
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
|
||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||
# Keep values structured as long as possible and only convert to dict
|
||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||
# data integrity and type information.
|
||||
result = WorkflowEntry._handle_special_values(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
raise TypeError("handle_special_values expects a mapping input")
|
||||
|
||||
@staticmethod
|
||||
def _handle_special_values(value: Any):
|
||||
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
if isinstance(value, Mapping):
|
||||
res: dict[str, SerializedSpecialValue] = {}
|
||||
for k, v in value.items():
|
||||
res[k] = WorkflowEntry._handle_special_values(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
res_list: list[SerializedSpecialValue] = []
|
||||
for item in value:
|
||||
res_list.append(WorkflowEntry._handle_special_values(item))
|
||||
return res_list
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
return dict(value.to_dict())
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from copy import deepcopy
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -33,8 +33,8 @@ def _make_graph_state():
|
||||
],
|
||||
)
|
||||
def test_run_uses_single_node_execution_branch(
|
||||
single_iteration_run: Any,
|
||||
single_loop_run: Any,
|
||||
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
|
||||
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
|
||||
) -> None:
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app"
|
||||
@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "other-node",
|
||||
"data": {
|
||||
"type": "answer",
|
||||
"title": "Answer",
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": "other-node",
|
||||
"target": "loop-node",
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
original_graph_dict = deepcopy(workflow.graph_dict)
|
||||
|
||||
_, _, graph_runtime_state = _make_graph_state()
|
||||
seen_configs: list[object] = []
|
||||
@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
seen_configs.append(value)
|
||||
return original_validate_python(value)
|
||||
|
||||
class FakeNodeClass:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
|
||||
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
|
||||
):
|
||||
runner._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
)
|
||||
|
||||
assert seen_configs == [workflow.graph_dict["nodes"][0]]
|
||||
assert workflow.graph_dict == original_graph_dict
|
||||
graph_config = graph_init.call_args.kwargs["graph_config"]
|
||||
assert graph_config is not workflow.graph_dict
|
||||
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
|
||||
assert graph_config["edges"] == []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user