fix(workflow): issues in workflow parallels

This commit is contained in:
takatost 2024-08-16 22:47:58 +08:00
parent 352c45c8a2
commit 5d7865737f
3 changed files with 26 additions and 22 deletions

View File

@ -39,11 +39,11 @@ class WorkflowLoggingCallback(WorkflowCallback):
event: GraphEngineEvent
) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[on_workflow_run_started]", color='pink')
self.print_text("\n[GraphRunStartedEvent]", color='pink')
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[on_workflow_run_succeeded]", color='green')
self.print_text("\n[GraphRunSucceededEvent]", color='green')
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(
event=event
@ -90,12 +90,10 @@ class WorkflowLoggingCallback(WorkflowCallback):
"""
Workflow node execute started
"""
route_node_state = event.route_node_state
node_type = event.node_type.value
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
self.print_text(f"Node ID: {route_node_state.node_id}", color='yellow')
self.print_text(f"Type: {node_type}", color='yellow')
self.print_text(f"Node ID: {event.node_id}", color='yellow')
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
self.print_text(f"Type: {event.node_type.value}", color='yellow')
def on_workflow_node_execute_succeeded(
self,
@ -105,11 +103,11 @@ class WorkflowLoggingCallback(WorkflowCallback):
Workflow node execute succeeded
"""
route_node_state = event.route_node_state
node_type = event.node_type.value
self.print_text("\n[NodeRunSucceededEvent]", color='green')
self.print_text(f"Node ID: {route_node_state.node_id}", color='green')
self.print_text(f"Type: {node_type}", color='green')
self.print_text(f"Node ID: {event.node_id}", color='green')
self.print_text(f"Node Title: {event.node_data.title}", color='green')
self.print_text(f"Type: {event.node_type.value}", color='green')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
@ -132,11 +130,11 @@ class WorkflowLoggingCallback(WorkflowCallback):
Workflow node execute failed
"""
route_node_state = event.route_node_state
node_type = event.node_type.value
self.print_text("\n[NodeRunFailedEvent]", color='red')
self.print_text(f"Node ID: {route_node_state.node_id}", color='red')
self.print_text(f"Type: {node_type}", color='red')
self.print_text(f"Node ID: {event.node_id}", color='red')
self.print_text(f"Node Title: {event.node_data.title}", color='red')
self.print_text(f"Type: {event.node_type.value}", color='red')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result

View File

@ -320,14 +320,18 @@ class GraphEngine:
break
yield event
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(threads):
q.put(None)
if isinstance(event, NodeRunSucceededEvent) and event.node_data.title == 'LLM 4':
print("LLM 4 succeeded")
continue
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error)
if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(threads):
q.put(None)
continue
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error)
except queue.Empty:
continue

View File

@ -40,7 +40,7 @@ class AnswerStreamProcessor(StreamProcessor):
if event.in_iteration_id:
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
@ -55,6 +55,8 @@ class AnswerStreamProcessor(StreamProcessor):
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.node_data.title == 'LLM 4':
print("LLM 4 succeeded1")
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: