# dify/api/core/workflow/graph.py

from collections.abc import Callable
from typing import Optional
from pydantic import BaseModel
class GraphNode(BaseModel):
    """
    One node of a workflow graph, plus the edge data recorded for it while the
    graph was being assembled.
    """

    #: node id
    id: str
    #: id of the node this node was reached from, if any
    predecessor_node_id: Optional[str] = None
    #: ids of the nodes this node links to, in the order they were added
    children_node_ids: list[str] = []
    #: optional condition function deciding whether the node can be executed
    run_condition_callback: Optional[Callable] = None
    #: original node config
    node_config: dict
    #: original config of the edge leading into this node
    source_edge_config: Optional[dict] = None
    #: original config of the edge leading out of this node
    target_edge_config: Optional[dict] = None

    def add_child(self, node_id: str) -> None:
        """
        Record ``node_id`` as a child of this node.

        The id is always appended; duplicates are not filtered out.

        :param node_id: id of the child node
        """
        children = self.children_node_ids
        children.append(node_id)
class Graph(BaseModel):
    """
    Directed graph of workflow nodes, built incrementally from edge configs.

    Each ``GraphNode`` keeps a list of child ids but only a single predecessor
    and a single edge config per direction; see ``add_edge`` for how repeated
    edges overwrite that per-node data.
    """

    #: graph config from workflow
    graph_config: dict
    #: graph nodes, keyed by node id
    graph_nodes: dict[str, GraphNode] = {}
    #: root node of the graph (the first node ever added)
    root_node: Optional[GraphNode] = None

    def add_edge(self, edge_config: dict,
                 source_node_config: dict,
                 target_node_config: dict,
                 run_condition_callback: Optional[Callable] = None) -> None:
        """
        Add edge to the graph, creating the source/target nodes on first sight.

        The edge is silently ignored when either node config has no truthy ``id``.

        When a node already exists, this edge's data replaces what it stored:
        the source's ``target_edge_config`` and the target's
        ``predecessor_node_id``, ``run_condition_callback`` and
        ``source_edge_config`` are overwritten. The target id is appended to the
        source's children without de-duplication.

        :param edge_config: edge config
        :param source_node_config: source node config (must carry an ``id``)
        :param target_node_config: target node config (must carry an ``id``)
        :param run_condition_callback: condition callback attached to the target node
        """
        source_node_id = source_node_config.get('id')
        if not source_node_id:
            return

        target_node_id = target_node_config.get('id')
        if not target_node_id:
            return

        if source_node_id not in self.graph_nodes:
            source_graph_node = GraphNode(
                id=source_node_id,
                node_config=source_node_config,
                children_node_ids=[target_node_id],
                target_edge_config=edge_config,
            )
            self.add_graph_node(source_graph_node)
        else:
            source_node = self.graph_nodes[source_node_id]
            source_node.add_child(target_node_id)
            # Only the most recent outgoing edge config is kept.
            source_node.target_edge_config = edge_config

        if target_node_id not in self.graph_nodes:
            target_graph_node = GraphNode(
                id=target_node_id,
                predecessor_node_id=source_node_id,
                node_config=target_node_config,
                run_condition_callback=run_condition_callback,
                source_edge_config=edge_config,
            )
            self.add_graph_node(target_graph_node)
        else:
            target_node = self.graph_nodes[target_node_id]
            # A node tracks a single predecessor: the latest edge wins.
            target_node.predecessor_node_id = source_node_id
            target_node.run_condition_callback = run_condition_callback
            target_node.source_edge_config = edge_config

    def add_graph_node(self, graph_node: GraphNode) -> None:
        """
        Add graph node to the graph.

        Adding a node whose id is already present is a no-op. The first node
        ever added becomes ``root_node``. The node object is stored as-is (not
        copied).

        :param graph_node: graph node
        """
        if graph_node.id in self.graph_nodes:
            return

        if not self.graph_nodes:
            self.root_node = graph_node

        self.graph_nodes[graph_node.id] = graph_node

    def get_root_node(self) -> Optional[GraphNode]:
        """
        Get root node of the graph.

        :return: root node, or None if no node has been added yet
        """
        return self.root_node

    def get_descendants_graph(self, node_id: str) -> Optional["Graph"]:
        """
        Get descendants graph of the specific node.

        The returned graph is rooted at ``node_id`` and contains every node
        reachable from it via ``children_node_ids``. It shares the
        ``GraphNode`` objects with this graph (they are not copied), so mutating
        a node through either graph affects both. Cycles are tolerated: each
        node is added and expanded at most once.

        :param node_id: node id
        :return: descendants graph, or None if ``node_id`` is unknown or has no children
        """
        graph_node = self.graph_nodes.get(node_id)
        if graph_node is None or not graph_node.children_node_ids:
            return None

        descendants_graph = Graph(graph_config=self.graph_config)
        descendants_graph.add_graph_node(graph_node)

        for child_node_id in graph_node.children_node_ids:
            self._add_descendants_graph_nodes(descendants_graph, child_node_id)

        return descendants_graph

    def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
        """
        Add ``node_id`` and every node reachable from it to ``descendants_graph``.

        Nodes are added in depth-first pre-order (children in list order).
        A node already present in ``descendants_graph`` is not expanded again,
        so cycles terminate and shared sub-branches are walked only once. Ids
        that are not nodes of this graph are skipped. An explicit stack is used
        instead of recursion so deep graphs cannot hit the recursion limit.

        :param descendants_graph: descendants graph
        :param node_id: node id
        """
        stack = [node_id]
        while stack:
            current_id = stack.pop()
            if current_id in descendants_graph.graph_nodes:
                continue

            graph_node = self.graph_nodes.get(current_id)
            if graph_node is None:
                continue

            descendants_graph.add_graph_node(graph_node)
            # Push in reverse so the first child is popped (visited) first,
            # matching the order of the former recursive traversal.
            stack.extend(reversed(graph_node.children_node_ids))