mirror of https://github.com/langgenius/dify.git
feat(workflow): add thread pool
This commit is contained in:
parent
70aced0100
commit
166365a502
|
|
@ -1,3 +1,4 @@
|
|||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
|
|
@ -64,6 +65,8 @@ class GraphEngine:
|
|||
max_execution_steps: int,
|
||||
max_execution_time: int
|
||||
) -> None:
|
||||
## init thread pool
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
self.graph = graph
|
||||
self.init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -368,7 +371,7 @@ class GraphEngine:
|
|||
q: queue.Queue = queue.Queue()
|
||||
|
||||
# Create a list to store the threads
|
||||
threads = []
|
||||
futures = []
|
||||
|
||||
# new thread
|
||||
for edge in edge_mappings:
|
||||
|
|
@ -378,17 +381,16 @@ class GraphEngine:
|
|||
):
|
||||
continue
|
||||
|
||||
thread = threading.Thread(target=self._run_parallel_node, kwargs={
|
||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
'q': q,
|
||||
'parallel_id': parallel_id,
|
||||
'parallel_start_node_id': edge.target_node_id,
|
||||
'parent_parallel_id': in_parallel_id,
|
||||
'parent_parallel_start_node_id': parallel_start_node_id,
|
||||
})
|
||||
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
futures.append(
|
||||
self.thread_pool.submit(self._run_parallel_node, **{
|
||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
'q': q,
|
||||
'parallel_id': parallel_id,
|
||||
'parallel_start_node_id': edge.target_node_id,
|
||||
'parent_parallel_id': in_parallel_id,
|
||||
'parent_parallel_start_node_id': parallel_start_node_id,
|
||||
})
|
||||
)
|
||||
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
|
|
@ -401,7 +403,7 @@ class GraphEngine:
|
|||
if event.parallel_id == parallel_id:
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(threads):
|
||||
if succeeded_count == len(futures):
|
||||
q.put(None)
|
||||
|
||||
continue
|
||||
|
|
@ -410,9 +412,8 @@ class GraphEngine:
|
|||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# Join all threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
# wait all threads
|
||||
wait(futures)
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
|
|
|
|||
Loading…
Reference in New Issue