From 286741e139227c121632fbf6478d271e0b96ae56 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Mon, 2 Dec 2024 21:13:39 +0800 Subject: [PATCH] fix: iteration node use the main thread pool --- api/core/tools/provider/builtin/comfyui/comfyui.py | 2 +- api/core/workflow/nodes/iteration/iteration_node.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.py b/api/core/tools/provider/builtin/comfyui/comfyui.py index bab690af82..114260c48a 100644 --- a/api/core/tools/provider/builtin/comfyui/comfyui.py +++ b/api/core/tools/provider/builtin/comfyui/comfyui.py @@ -15,7 +15,7 @@ class ComfyUIProvider(BuiltinToolProviderController): try: ws.connect(ws_address) - except Exception as e: + except Exception: raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}") finally: ws.close() diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 22f242a42f..8aa9811d33 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -162,7 +162,8 @@ class IterationNode(BaseNode[IterationNodeData]): if self.node_data.is_parallel: futures: list[Future] = [] q = Queue() - thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) + thread_pool = graph_engine.workflow_thread_pool_mapping[self.thread_pool_id] + thread_pool._max_workers = self.node_data.parallel_nums for index, item in enumerate(iterator_list_value): future: Future = thread_pool.submit( self._run_single_iter_parallel, @@ -235,7 +236,10 @@ class IterationNode(BaseNode[IterationNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}, - metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map}, + metadata={ + NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map, + "total_tokens": graph_engine.graph_runtime_state.total_tokens, + }, ) ) except IterationNodeError as e: @@ -258,6 +262,7 @@ class IterationNode(BaseNode[IterationNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, ) ) finally: