From 166365a5021abc1883872a634f548628f1866429 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 2 Sep 2024 19:02:10 +0800 Subject: [PATCH] feat(workflow): add thread pool --- .../workflow/graph_engine/graph_engine.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b3b64722c5..3bb96a619f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor, wait import logging import queue import threading @@ -64,6 +65,8 @@ class GraphEngine: max_execution_steps: int, max_execution_time: int ) -> None: + ## init thread pool + self.thread_pool = ThreadPoolExecutor(max_workers=10) self.graph = graph self.init_params = GraphInitParams( tenant_id=tenant_id, @@ -368,7 +371,7 @@ class GraphEngine: q: queue.Queue = queue.Queue() # Create a list to store the threads - threads = [] + futures = [] # new thread for edge in edge_mappings: @@ -378,17 +381,16 @@ class GraphEngine: ): continue - thread = threading.Thread(target=self._run_parallel_node, kwargs={ - 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] - 'q': q, - 'parallel_id': parallel_id, - 'parallel_start_node_id': edge.target_node_id, - 'parent_parallel_id': in_parallel_id, - 'parent_parallel_start_node_id': parallel_start_node_id, - }) - - threads.append(thread) - thread.start() + futures.append( + self.thread_pool.submit(self._run_parallel_node, **{ + 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] + 'q': q, + 'parallel_id': parallel_id, + 'parallel_start_node_id': edge.target_node_id, + 'parent_parallel_id': in_parallel_id, + 'parent_parallel_start_node_id': parallel_start_node_id, + }) + ) succeeded_count = 0 while True: @@ -401,7 +403,7 @@ class GraphEngine: if event.parallel_id == parallel_id: if isinstance(event, ParallelBranchRunSucceededEvent): succeeded_count += 1 - if succeeded_count == len(threads): + if succeeded_count == len(futures): q.put(None) continue @@ -410,9 +412,8 @@ class GraphEngine: except queue.Empty: continue - # Join all threads - for thread in threads: - thread.join() + # wait all threads + wait(futures) # get final node id final_node_id = parallel.end_to_node_id