From 1d88b62e25eb0d0fe2647b08de8cfee36cfa7548 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 20 Aug 2024 23:28:11 +0800 Subject: [PATCH] fix(workflow): fix node link to previous node issue --- .../workflow/graph_engine/entities/graph.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 4061075ba5..714eb55fd2 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,6 +1,7 @@ import uuid from collections.abc import Mapping from typing import Any, Optional, cast +from cycler import V from pydantic import BaseModel, Field @@ -152,6 +153,12 @@ class Graph(BaseModel): if not root_node_id or root_node_id not in root_node_ids: raise ValueError(f"Root node id {root_node_id} not found in the graph") + + # Check whether it is connected to the previous node + cls._check_connected_to_previous_node( + route=[root_node_id], + edge_mapping=edge_mapping + ) # fetch all node ids from root node node_ids = [root_node_id] @@ -267,6 +274,30 @@ class Graph(BaseModel): node_id=graph_edge.target_node_id ) + @classmethod + def _check_connected_to_previous_node( + cls, + route: list[str], + edge_mapping: dict[str, list[GraphEdge]] + ) -> None: + """ + Check whether it is connected to the previous node + """ + new_route = list(route) + + for graph_edge in edge_mapping.get(new_route[-1], []): + if not graph_edge.target_node_id: + continue + + if graph_edge.target_node_id in new_route: + raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.") + + new_route.append(graph_edge.target_node_id) + cls._check_connected_to_previous_node( + route=new_route, + edge_mapping=edge_mapping, + ) + @classmethod def _recursively_add_parallels(cls, edge_mapping: dict[str, list[GraphEdge]],