From 5d7865737fb4bcfbf3c2490993adad8f8090c29d Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 16 Aug 2024 22:47:58 +0800 Subject: [PATCH] fix(workflow): issues in workflow parallels --- .../app/apps/workflow_logging_callback.py | 26 +++++++++---------- .../workflow/graph_engine/graph_engine.py | 18 ++++++++----- .../nodes/answer/answer_stream_processor.py | 4 ++- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py index dbbe027acb..4e8f3644b1 100644 --- a/api/core/app/apps/workflow_logging_callback.py +++ b/api/core/app/apps/workflow_logging_callback.py @@ -39,11 +39,11 @@ class WorkflowLoggingCallback(WorkflowCallback): event: GraphEngineEvent ) -> None: if isinstance(event, GraphRunStartedEvent): - self.print_text("\n[on_workflow_run_started]", color='pink') + self.print_text("\n[GraphRunStartedEvent]", color='pink') elif isinstance(event, GraphRunSucceededEvent): - self.print_text("\n[on_workflow_run_succeeded]", color='green') + self.print_text("\n[GraphRunSucceededEvent]", color='green') elif isinstance(event, GraphRunFailedEvent): - self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red') + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red') elif isinstance(event, NodeRunStartedEvent): self.on_workflow_node_execute_started( event=event @@ -90,12 +90,10 @@ class WorkflowLoggingCallback(WorkflowCallback): """ Workflow node execute started """ - route_node_state = event.route_node_state - node_type = event.node_type.value - self.print_text("\n[NodeRunStartedEvent]", color='yellow') - self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow') - self.print_text(f"Type: {node_type}", color='yellow') + self.print_text(f"Node ID: {event.node_id}", color='yellow') + self.print_text(f"Node Title: {event.node_data.title}", color='yellow') + self.print_text(f"Type: {event.node_type.value}", color='yellow') def on_workflow_node_execute_succeeded( self, @@ -105,11 +103,11 @@ class WorkflowLoggingCallback(WorkflowCallback): Workflow node execute succeeded """ route_node_state = event.route_node_state - node_type = event.node_type.value self.print_text("\n[NodeRunSucceededEvent]", color='green') - self.print_text(f"Node ID: {route_node_state.node_id}", color='green') - self.print_text(f"Type: {node_type}", color='green') + self.print_text(f"Node ID: {event.node_id}", color='green') + self.print_text(f"Node Title: {event.node_data.title}", color='green') + self.print_text(f"Type: {event.node_type.value}", color='green') if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result @@ -132,11 +130,11 @@ class WorkflowLoggingCallback(WorkflowCallback): Workflow node execute failed """ route_node_state = event.route_node_state - node_type = event.node_type.value self.print_text("\n[NodeRunFailedEvent]", color='red') - self.print_text(f"Node ID: {route_node_state.node_id}", color='red') - self.print_text(f"Type: {node_type}", color='red') + self.print_text(f"Node ID: {event.node_id}", color='red') + self.print_text(f"Node Title: {event.node_data.title}", color='red') + self.print_text(f"Type: {event.node_type.value}", color='red') if route_node_state.node_run_result: node_run_result = route_node_state.node_run_result diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index f4952f7bdf..fa5cf28868 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -320,14 +320,18 @@ class GraphEngine: break yield event - if isinstance(event, ParallelBranchRunSucceededEvent): - succeeded_count += 1 - if succeeded_count == len(threads): - q.put(None) + if isinstance(event, NodeRunSucceededEvent) and event.node_data.title == 'LLM 4': + print("LLM 4 succeeded") - continue - elif isinstance(event, ParallelBranchRunFailedEvent): - raise GraphRunFailedError(event.error) + if event.parallel_id == parallel_id: + if isinstance(event, ParallelBranchRunSucceededEvent): + succeeded_count += 1 + if succeeded_count == len(threads): + q.put(None) + + continue + elif isinstance(event, ParallelBranchRunFailedEvent): + raise GraphRunFailedError(event.error) except queue.Empty: continue diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 9904d0f943..9982ab03ef 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -40,7 +40,7 @@ class AnswerStreamProcessor(StreamProcessor): if event.in_iteration_id: yield event continue - + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[ event.route_node_state.node_id @@ -55,6 +55,8 @@ class AnswerStreamProcessor(StreamProcessor): yield event elif isinstance(event, NodeRunSucceededEvent): yield event + if event.node_data.title == 'LLM 4': + print("LLM 4 succeeded1") if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: # update self.route_position after all stream event finished for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: