mirror of https://github.com/langgenius/dify.git
feat(graph_engine): Add NodeExecutionType.ROOT and auto mark skipped in Graph.init
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
d7e0c5f759
commit
bfbb36756a
|
|
@ -57,6 +57,7 @@ class NodeExecutionType(StrEnum):
|
|||
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
|
||||
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
|
||||
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
|
||||
ROOT = "root" # Nodes that can serve as execution entry points
|
||||
|
||||
|
||||
class ErrorStrategy(StrEnum):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .edge import Edge
|
||||
|
|
@ -186,6 +186,72 @@ class Graph:
|
|||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
nodes: dict[str, Node],
|
||||
edges: dict[str, Edge],
|
||||
in_edges: dict[str, list[str]],
|
||||
out_edges: dict[str, list[str]],
|
||||
active_root_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Mark nodes and edges from inactive root branches as skipped.
|
||||
|
||||
Algorithm:
|
||||
1. Mark inactive root nodes as skipped
|
||||
2. For skipped nodes, mark all their outgoing edges as skipped
|
||||
3. For each edge marked as skipped, check its target node:
|
||||
- If ALL incoming edges are skipped, mark the node as skipped
|
||||
- Otherwise, leave the node state unchanged
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
:param edges: mapping of edge ID to edge instance
|
||||
:param in_edges: mapping of node ID to incoming edge IDs
|
||||
:param out_edges: mapping of node ID to outgoing edge IDs
|
||||
:param active_root_id: ID of the active root node
|
||||
"""
|
||||
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
|
||||
top_level_roots: list[str] = [
|
||||
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
|
||||
]
|
||||
|
||||
# If there's only one root or the active root is not a top-level root, no marking needed
|
||||
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
|
||||
return
|
||||
|
||||
# Mark inactive root nodes as skipped
|
||||
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
|
||||
for root_id in inactive_roots:
|
||||
if root_id in nodes:
|
||||
nodes[root_id].state = NodeState.SKIPPED
|
||||
|
||||
# Recursively mark downstream nodes and edges
|
||||
def mark_downstream(node_id: str) -> None:
|
||||
"""Recursively mark downstream nodes and edges as skipped."""
|
||||
if nodes[node_id].state != NodeState.SKIPPED:
|
||||
return
|
||||
# If this node is skipped, mark all its outgoing edges as skipped
|
||||
out_edge_ids = out_edges.get(node_id, [])
|
||||
for edge_id in out_edge_ids:
|
||||
edge = edges[edge_id]
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
# Check the target node of this edge
|
||||
target_node = nodes[edge.head]
|
||||
in_edge_ids = in_edges.get(target_node.id, [])
|
||||
in_edge_states = [edges[eid].state for eid in in_edge_ids]
|
||||
|
||||
# If all incoming edges are skipped, mark the node as skipped
|
||||
if all(state == NodeState.SKIPPED for state in in_edge_states):
|
||||
target_node.state = NodeState.SKIPPED
|
||||
# Recursively process downstream nodes
|
||||
mark_downstream(target_node.id)
|
||||
|
||||
# Process each inactive root and its downstream nodes
|
||||
for root_id in inactive_roots:
|
||||
mark_downstream(root_id)
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
|
|
@ -227,6 +293,9 @@ class Graph:
|
|||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
# Mark inactive root branches as skipped
|
||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||
|
||||
# Create and return the graph
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from collections.abc import Mapping
|
|||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
|
@ -11,6 +11,7 @@ from core.workflow.nodes.start.entities import StartNodeData
|
|||
|
||||
class StartNode(Node):
|
||||
node_type = NodeType.START
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
_node_data: StartNodeData
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,281 @@
|
|||
"""Unit tests for Graph class methods."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.graph.edge import Edge
|
||||
from core.workflow.graph.graph import Graph
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node:
|
||||
"""Create a mock node for testing."""
|
||||
node = Mock(spec=Node)
|
||||
node.id = node_id
|
||||
node.execution_type = execution_type
|
||||
node.state = state
|
||||
node.node_type = NodeType.START
|
||||
return node
|
||||
|
||||
|
||||
class TestMarkInactiveRootBranches:
|
||||
"""Test cases for _mark_inactive_root_branches method."""
|
||||
|
||||
def test_single_root_no_marking(self):
|
||||
"""Test that single root graph doesn't mark anything as skipped."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"]}
|
||||
out_edges = {"root1": ["edge1"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
|
||||
def test_multiple_roots_mark_inactive(self):
|
||||
"""Test marking inactive root branches with multiple root nodes."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
|
||||
out_edges = {"root1": ["edge1"], "root2": ["edge2"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
|
||||
def test_shared_downstream_node(self):
|
||||
"""Test that shared downstream nodes are not skipped if at least one path is active."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
"shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="child1", head="shared", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"child1": ["edge1"],
|
||||
"child2": ["edge2"],
|
||||
"shared": ["edge3", "edge4"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"child1": ["edge3"],
|
||||
"child2": ["edge4"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.SKIPPED
|
||||
assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.UNKNOWN
|
||||
assert edges["edge4"].state == NodeState.SKIPPED
|
||||
|
||||
def test_deep_branch_marking(self):
|
||||
"""Test marking deep branches with multiple levels."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE),
|
||||
"level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE),
|
||||
"level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE),
|
||||
"level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE),
|
||||
"level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"),
|
||||
"edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"level1_a": ["edge1"],
|
||||
"level1_b": ["edge2"],
|
||||
"level2_a": ["edge3"],
|
||||
"level2_b": ["edge4"],
|
||||
"level3": ["edge5"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"level1_a": ["edge3"],
|
||||
"level1_b": ["edge4"],
|
||||
"level2_b": ["edge5"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["level1_a"].state == NodeState.UNKNOWN
|
||||
assert nodes["level1_b"].state == NodeState.SKIPPED
|
||||
assert nodes["level2_a"].state == NodeState.UNKNOWN
|
||||
assert nodes["level2_b"].state == NodeState.SKIPPED
|
||||
assert nodes["level3"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.UNKNOWN
|
||||
assert edges["edge4"].state == NodeState.SKIPPED
|
||||
assert edges["edge5"].state == NodeState.SKIPPED
|
||||
|
||||
def test_non_root_execution_type(self):
|
||||
"""Test that nodes with non-ROOT execution type are not treated as root nodes."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {"child1": ["edge1"], "child2": ["edge2"]}
|
||||
out_edges = {"root1": ["edge1"], "non_root": ["edge2"]}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped
|
||||
assert nodes["child1"].state == NodeState.UNKNOWN
|
||||
assert nodes["child2"].state == NodeState.UNKNOWN
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.UNKNOWN
|
||||
|
||||
def test_empty_graph(self):
|
||||
"""Test handling of empty graph structures."""
|
||||
nodes = {}
|
||||
edges = {}
|
||||
in_edges = {}
|
||||
out_edges = {}
|
||||
|
||||
# Should not raise any errors
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent")
|
||||
|
||||
def test_three_roots_mark_two_inactive(self):
|
||||
"""Test with three root nodes where two should be marked inactive."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
|
||||
"child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE),
|
||||
"child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE),
|
||||
"child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"child1": ["edge1"],
|
||||
"child2": ["edge2"],
|
||||
"child3": ["edge3"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"root3": ["edge3"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2")
|
||||
|
||||
assert nodes["root1"].state == NodeState.SKIPPED
|
||||
assert nodes["root2"].state == NodeState.UNKNOWN # Active root
|
||||
assert nodes["root3"].state == NodeState.SKIPPED
|
||||
assert nodes["child1"].state == NodeState.SKIPPED
|
||||
assert nodes["child2"].state == NodeState.UNKNOWN
|
||||
assert nodes["child3"].state == NodeState.SKIPPED
|
||||
assert edges["edge1"].state == NodeState.SKIPPED
|
||||
assert edges["edge2"].state == NodeState.UNKNOWN
|
||||
assert edges["edge3"].state == NodeState.SKIPPED
|
||||
|
||||
def test_convergent_paths(self):
|
||||
"""Test convergent paths where multiple inactive branches lead to same node."""
|
||||
nodes = {
|
||||
"root1": create_mock_node("root1", NodeExecutionType.ROOT),
|
||||
"root2": create_mock_node("root2", NodeExecutionType.ROOT),
|
||||
"root3": create_mock_node("root3", NodeExecutionType.ROOT),
|
||||
"mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE),
|
||||
"mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE),
|
||||
"convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE),
|
||||
}
|
||||
|
||||
edges = {
|
||||
"edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"),
|
||||
"edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"),
|
||||
"edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"),
|
||||
"edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"),
|
||||
"edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"),
|
||||
}
|
||||
|
||||
in_edges = {
|
||||
"mid1": ["edge1"],
|
||||
"mid2": ["edge2"],
|
||||
"convergent": ["edge3", "edge4", "edge5"],
|
||||
}
|
||||
out_edges = {
|
||||
"root1": ["edge1"],
|
||||
"root2": ["edge2"],
|
||||
"root3": ["edge3"],
|
||||
"mid1": ["edge4"],
|
||||
"mid2": ["edge5"],
|
||||
}
|
||||
|
||||
Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1")
|
||||
|
||||
assert nodes["root1"].state == NodeState.UNKNOWN
|
||||
assert nodes["root2"].state == NodeState.SKIPPED
|
||||
assert nodes["root3"].state == NodeState.SKIPPED
|
||||
assert nodes["mid1"].state == NodeState.UNKNOWN
|
||||
assert nodes["mid2"].state == NodeState.SKIPPED
|
||||
assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1
|
||||
assert edges["edge1"].state == NodeState.UNKNOWN
|
||||
assert edges["edge2"].state == NodeState.SKIPPED
|
||||
assert edges["edge3"].state == NodeState.SKIPPED
|
||||
assert edges["edge4"].state == NodeState.UNKNOWN
|
||||
assert edges["edge5"].state == NodeState.SKIPPED
|
||||
Loading…
Reference in New Issue