feat(workflow): add thread pool

This commit is contained in:
takatost 2024-09-02 19:02:10 +08:00
parent 70aced0100
commit 166365a502
1 changed files with 17 additions and 16 deletions

View File

@ -1,3 +1,4 @@
from concurrent.futures import ThreadPoolExecutor, wait
import logging
import queue
import threading
@ -64,6 +65,8 @@ class GraphEngine:
max_execution_steps: int,
max_execution_time: int
) -> None:
## init thread pool
self.thread_pool = ThreadPoolExecutor(max_workers=10)
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
@ -368,7 +371,7 @@ class GraphEngine:
q: queue.Queue = queue.Queue()
# Create a list to store the threads
threads = []
futures = []
# new thread
for edge in edge_mappings:
@ -378,17 +381,16 @@ class GraphEngine:
):
continue
thread = threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'q': q,
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'parent_parallel_id': in_parallel_id,
'parent_parallel_start_node_id': parallel_start_node_id,
})
threads.append(thread)
thread.start()
futures.append(
self.thread_pool.submit(self._run_parallel_node, **{
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'q': q,
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'parent_parallel_id': in_parallel_id,
'parent_parallel_start_node_id': parallel_start_node_id,
})
)
succeeded_count = 0
while True:
@ -401,7 +403,7 @@ class GraphEngine:
if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(threads):
if succeeded_count == len(futures):
q.put(None)
continue
@ -410,9 +412,8 @@ class GraphEngine:
except queue.Empty:
continue
# Join all threads
for thread in threads:
thread.join()
# wait all threads
wait(futures)
# get final node id
final_node_id = parallel.end_to_node_id