mirror of https://github.com/langgenius/dify.git
152 lines
4.5 KiB
Python
152 lines
4.5 KiB
Python
from collections.abc import Callable
|
|
from typing import Optional
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class GraphNode(BaseModel):
    """
    One node of a workflow graph, holding its raw config and links to neighbours.
    """

    # id of this node
    id: str

    # id of the predecessor node, if one has been recorded
    predecessor_node_id: Optional[str] = None

    # ids of the direct children of this node, in the order they were added
    children_node_ids: list[str] = []

    # optional condition function deciding whether this node can be executed
    run_condition_callback: Optional[Callable] = None

    # the original node config this node was built from
    node_config: dict

    # original config of the edge leading into this node, if any
    source_edge_config: Optional[dict] = None

    # original config of an edge leading out of this node (a single dict,
    # so only one outgoing edge config is held at a time)
    target_edge_config: Optional[dict] = None

    def add_child(self, node_id: str) -> None:
        """
        Append ``node_id`` to this node's list of children.

        No de-duplication is performed: adding the same id twice stores it twice.

        :param node_id: id of the child node
        """
        children = self.children_node_ids
        children.append(node_id)
|
|
|
|
|
|
class Graph(BaseModel):
    """
    Workflow graph assembled incrementally from edges.

    Nodes are stored by id in ``graph_nodes``; the first node ever added to the
    graph becomes ``root_node``.
    """

    # graph config from the workflow
    graph_config: dict

    # all nodes of the graph, keyed by node id
    graph_nodes: dict[str, GraphNode] = {}

    # root node of the graph (the first node that was added), if any
    root_node: Optional[GraphNode] = None

    def add_edge(self, edge_config: dict,
                 source_node_config: dict,
                 target_node_config: dict,
                 run_condition_callback: Optional[Callable] = None) -> None:
        """
        Add an edge, creating its source and target nodes if they are new.

        The edge is silently ignored when either node config has no truthy ``id``.
        For an already-known source node, the target id is appended to its
        children and its ``target_edge_config`` is overwritten with this edge.
        For an already-known target node, ``predecessor_node_id``,
        ``run_condition_callback`` and ``source_edge_config`` are overwritten,
        so only the most recently added incoming edge is remembered.

        :param edge_config: edge config
        :param source_node_config: config of the edge's source node
        :param target_node_config: config of the edge's target node
        :param run_condition_callback: condition callback stored on the target node
        """
        source_node_id = source_node_config.get('id')
        if not source_node_id:
            return

        target_node_id = target_node_config.get('id')
        if not target_node_id:
            return

        if source_node_id not in self.graph_nodes:
            source_graph_node = GraphNode(
                id=source_node_id,
                node_config=source_node_config,
                children_node_ids=[target_node_id],
                target_edge_config=edge_config,
            )

            self.add_graph_node(source_graph_node)
        else:
            source_node = self.graph_nodes[source_node_id]
            source_node.add_child(target_node_id)
            # NOTE(review): a single dict per node, so with several outgoing
            # edges only the last one added is kept here.
            source_node.target_edge_config = edge_config

        if target_node_id not in self.graph_nodes:
            target_graph_node = GraphNode(
                id=target_node_id,
                predecessor_node_id=source_node_id,
                node_config=target_node_config,
                run_condition_callback=run_condition_callback,
                source_edge_config=edge_config,
            )

            self.add_graph_node(target_graph_node)
        else:
            target_node = self.graph_nodes[target_node_id]
            target_node.predecessor_node_id = source_node_id
            target_node.run_condition_callback = run_condition_callback
            target_node.source_edge_config = edge_config

    def add_graph_node(self, graph_node: GraphNode) -> None:
        """
        Add a graph node to the graph; a node whose id is already present is ignored.

        The very first node added to an empty graph becomes ``root_node``.

        :param graph_node: graph node
        """
        if graph_node.id in self.graph_nodes:
            return

        if not self.graph_nodes:
            self.root_node = graph_node

        self.graph_nodes[graph_node.id] = graph_node

    def get_root_node(self) -> Optional[GraphNode]:
        """
        Get the root node of the graph.

        :return: root node, or None if no node has been added yet
        """
        return self.root_node

    def get_descendants_graph(self, node_id: str) -> Optional["Graph"]:
        """
        Build a new graph containing ``node_id`` and every node reachable from it.

        The requested node is added first and therefore becomes the new graph's
        root. The returned graph shares the ``GraphNode`` objects of this graph
        (they are not copied). Shared descendants are included once, and cycles
        in the edges terminate instead of recursing forever.

        :param node_id: node id
        :return: descendants graph, or None if ``node_id`` is unknown or has no children
        """
        if node_id not in self.graph_nodes:
            return None

        graph_node = self.graph_nodes[node_id]
        if not graph_node.children_node_ids:
            return None

        descendants_graph = Graph(graph_config=self.graph_config)
        descendants_graph.add_graph_node(graph_node)

        for child_node_id in graph_node.children_node_ids:
            self._add_descendants_graph_nodes(descendants_graph, child_node_id)

        return descendants_graph

    def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
        """
        Add ``node_id`` and all nodes reachable from it to ``descendants_graph``.

        Traversal is an iterative depth-first pre-order walk, so the insertion order
        matches the previous recursive implementation on acyclic graphs. Nodes
        already present in ``descendants_graph`` are not expanded again. This
        bounds the work to one visit per node, prevents infinite recursion on
        cycles, and avoids hitting the interpreter recursion limit on long chains.
        Ids unknown to this graph are skipped.

        :param descendants_graph: graph being populated
        :param node_id: id of the node to start from
        """
        pending = [node_id]
        while pending:
            current_id = pending.pop()

            if current_id in descendants_graph.graph_nodes:
                # already collected via another path (or a cycle): don't re-walk it
                continue

            graph_node = self.graph_nodes.get(current_id)
            if graph_node is None:
                continue

            descendants_graph.add_graph_node(graph_node)

            # push in reverse so children are expanded left-to-right (pre-order)
            pending.extend(reversed(graph_node.children_node_ids))
|