From 755a9658c7191624daa7fb9713f13698d1834816 Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 18 Aug 2024 20:18:13 +0800 Subject: [PATCH] fix(workflow): add parallel id into published events --- .../advanced_chat/generate_task_pipeline.py | 2 + .../apps/workflow/generate_task_pipeline.py | 2 + api/core/app/apps/workflow_app_runner.py | 18 +++++ api/core/app/entities/queue_entities.py | 36 +++++++++ api/core/app/entities/task_entities.py | 22 +++++- .../task_pipeline/workflow_cycle_manage.py | 26 ++++++- api/core/workflow/entities/node_entities.py | 2 + .../workflow/graph_engine/entities/event.py | 14 ++++ .../workflow/graph_engine/graph_engine.py | 74 ++++++++++++++----- .../nodes/answer/answer_stream_processor.py | 2 - 10 files changed, 170 insertions(+), 28 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index c36f5493ea..80e756307f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -291,6 +291,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc workflow_node_execution = self._handle_workflow_node_execution_success(event) response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) @@ -301,6 +302,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index e9e490c578..f9cb8328fa 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -262,6 +262,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_node_execution = self._handle_workflow_node_execution_success(event) response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) @@ -272,6 +273,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index e77d271706..a07211f581 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -199,6 +199,8 @@ class WorkflowBasedAppRunner(AppRunner): node_data=event.node_data, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, node_run_index=event.route_node_state.index, predecessor_node_id=event.predecessor_node_id @@ -213,6 +215,8 @@ class WorkflowBasedAppRunner(AppRunner): node_data=event.node_data, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, inputs=event.route_node_state.node_run_result.inputs if event.route_node_state.node_run_result else {}, @@ -233,6 +237,8 @@ class WorkflowBasedAppRunner(AppRunner): node_data=event.node_data, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.route_node_state.start_at, inputs=event.route_node_state.node_run_result.inputs if event.route_node_state.node_run_result else {}, @@ -263,6 +269,8 @@ class WorkflowBasedAppRunner(AppRunner): QueueParallelBranchRunStartedEvent( parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, in_iteration_id=event.in_iteration_id ) ) @@ -271,6 +279,8 @@ class WorkflowBasedAppRunner(AppRunner): QueueParallelBranchRunSucceededEvent( parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, in_iteration_id=event.in_iteration_id ) ) @@ -279,6 +289,8 @@ class WorkflowBasedAppRunner(AppRunner): QueueParallelBranchRunFailedEvent( parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, in_iteration_id=event.in_iteration_id, error=event.error ) @@ -292,6 +304,8 @@ class WorkflowBasedAppRunner(AppRunner): node_data=event.iteration_node_data, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.start_at, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, inputs=event.inputs, @@ -308,6 +322,8 @@ class WorkflowBasedAppRunner(AppRunner): node_data=event.iteration_node_data, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, index=event.index, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, output=event.pre_iteration_output, @@ -322,6 +338,8 @@ class WorkflowBasedAppRunner(AppRunner): node_data=event.iteration_node_data, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, start_at=event.start_at, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, inputs=event.inputs, diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 471c982bd0..04226636d1 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -69,6 +69,10 @@ class QueueIterationStartEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int @@ -91,6 +95,10 @@ class QueueIterationNextEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" node_run_index: int output: Optional[Any] = None # output for the current iteration @@ -121,6 +129,10 @@ class QueueIterationCompletedEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" start_at: datetime node_run_index: int @@ -227,6 +239,10 @@ class QueueNodeStartedEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" start_at: datetime @@ -244,6 +260,10 @@ class QueueNodeSucceededEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" start_at: datetime inputs: Optional[dict[str, Any]] = None @@ -268,6 +288,10 @@ class QueueNodeFailedEvent(AppQueueEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" start_at: datetime inputs: Optional[dict[str, Any]] = None @@ -370,6 +394,10 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" @@ -382,6 +410,10 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" @@ -394,6 +426,10 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent): parallel_id: str parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 6470b1d3fd..41215a931a 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -219,6 +219,8 @@ class NodeStartStreamResponse(StreamResponse): inputs: Optional[dict] = None created_at: int extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -238,7 +240,9 @@ class NodeStartStreamResponse(StreamResponse): "predecessor_node_id": self.data.predecessor_node_id, "inputs": None, "created_at": self.data.created_at, - "extras": {} + "extras": {}, + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, } } @@ -268,6 +272,8 @@ class NodeFinishStreamResponse(StreamResponse): created_at: int finished_at: int files: Optional[list[dict]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -294,7 +300,9 @@ class NodeFinishStreamResponse(StreamResponse): "execution_metadata": None, "created_at": self.data.created_at, "finished_at": self.data.finished_at, - "files": [] + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, } } @@ -310,6 +318,8 @@ class ParallelBranchStartStreamResponse(StreamResponse): """ parallel_id: str parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None iteration_id: Optional[str] = None created_at: int @@ -329,6 +339,8 @@ class ParallelBranchFinishedStreamResponse(StreamResponse): """ parallel_id: str parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None iteration_id: Optional[str] = None status: str error: Optional[str] = None @@ -356,6 +368,8 @@ class IterationNodeStartStreamResponse(StreamResponse): extras: dict = {} metadata: dict = {} inputs: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_STARTED workflow_run_id: str @@ -379,6 +393,8 @@ class IterationNodeNextStreamResponse(StreamResponse): created_at: int pre_iteration_output: Optional[Any] = None extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -409,6 +425,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse): execution_metadata: Optional[dict] = None finished_at: int steps: int + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_COMPLETED workflow_run_id: str diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 940859f622..caca6d00b2 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -390,6 +390,8 @@ class WorkflowCycleManage: predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs_dict, created_at=int(workflow_node_execution.created_at.timestamp()), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ), ) @@ -405,10 +407,14 @@ class WorkflowCycleManage: return response def _workflow_node_finish_to_stream_response( - self, task_id: str, workflow_node_execution: WorkflowNodeExecution + self, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. + :param event: queue node succeeded or failed event :param task_id: task id :param workflow_node_execution: workflow node execution :return: @@ -436,6 +442,8 @@ class WorkflowCycleManage: created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ), ) @@ -458,6 +466,8 @@ class WorkflowCycleManage: data=ParallelBranchStartStreamResponse.Data( parallel_id=event.parallel_id, parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, created_at=int(time.time()), ) @@ -482,6 +492,8 @@ class WorkflowCycleManage: data=ParallelBranchFinishedStreamResponse.Data( parallel_id=event.parallel_id, parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed', error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, @@ -513,7 +525,9 @@ class WorkflowCycleManage: created_at=int(time.time()), extras={}, inputs=event.inputs or {}, - metadata=event.metadata or {} + metadata=event.metadata or {}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ) ) @@ -536,7 +550,9 @@ class WorkflowCycleManage: index=event.index, pre_iteration_output=event.output, created_at=int(time.time()), - extras={} + extras={}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ) ) @@ -566,7 +582,9 @@ class WorkflowCycleManage: total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), - steps=event.steps + steps=event.steps, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ) ) diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index afb92c9b2d..f4e90864fc 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -58,6 +58,8 @@ class NodeRunMetadataKey(Enum): ITERATION_INDEX = 'iteration_index' PARALLEL_ID = 'parallel_id' PARALLEL_START_NODE_ID = 'parallel_start_node_id' + PARENT_PARALLEL_ID = 'parent_parallel_id' + PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id' class NodeRunResult(BaseModel): diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index eae6ffec02..06dc4cb8f4 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -49,6 +49,10 @@ class BaseNodeEvent(GraphEngineEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" @@ -84,7 +88,13 @@ class NodeRunFailedEvent(BaseNodeEvent): class BaseParallelBranchEvent(GraphEngineEvent): parallel_id: str = Field(..., description="parallel id") + """parallel id""" parallel_start_node_id: str = Field(..., description="parallel start node id") + """parallel start node id""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" @@ -115,6 +125,10 @@ class BaseIterationEvent(GraphEngineEvent): """parallel id if node is in parallel""" parallel_start_node_id: Optional[str] = None """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" class IterationRunStartedEvent(BaseIterationEvent): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index fa5cf28868..812710e924 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -226,7 +226,9 @@ class GraphEngine: node_data=node_instance.node_data, route_node_state=route_node_state, parallel_id=in_parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=in_parallel_id, + parent_parallel_start_node_id=parallel_start_node_id ) raise e @@ -304,9 +306,11 @@ class GraphEngine: for edge in edge_mappings: thread = threading.Thread(target=self._run_parallel_node, kwargs={ 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] + 'q': q, 'parallel_id': parallel_id, 'parallel_start_node_id': edge.target_node_id, - 'q': q + 'parent_parallel_id': in_parallel_id, + 'parent_parallel_start_node_id': parallel_start_node_id, }) threads.append(thread) @@ -320,9 +324,6 @@ class GraphEngine: break yield event - if isinstance(event, NodeRunSucceededEvent) and event.node_data.title == 'LLM 4': - print("LLM 4 succeeded") - if event.parallel_id == parallel_id: if isinstance(event, ParallelBranchRunSucceededEvent): succeeded_count += 1 @@ -349,11 +350,15 @@ class GraphEngine: if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: break - def _run_parallel_node(self, - flask_app: Flask, - parallel_id: str, - parallel_start_node_id: str, - q: queue.Queue) -> None: + def _run_parallel_node( + self, + flask_app: Flask, + q: queue.Queue, + parallel_id: str, + parallel_start_node_id: str, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> None: """ Run parallel nodes """ @@ -361,7 +366,9 @@ class GraphEngine: try: q.put(ParallelBranchRunStartedEvent( parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id )) # run node @@ -376,12 +383,16 @@ class GraphEngine: # trigger graph run success event q.put(ParallelBranchRunSucceededEvent( parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id )) except GraphRunFailedError as e: q.put(ParallelBranchRunFailedEvent( parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, error=e.error )) except Exception as e: @@ -389,16 +400,22 @@ class GraphEngine: q.put(ParallelBranchRunFailedEvent( parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, error=str(e) )) finally: db.session.remove() - def _run_node(self, - node_instance: BaseNode, - route_node_state: RouteNodeState, - parallel_id: Optional[str] = None, - parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: + def _run_node( + self, + node_instance: BaseNode, + route_node_state: RouteNodeState, + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: """ Run node """ @@ -411,7 +428,9 @@ class GraphEngine: route_node_state=route_node_state, predecessor_node_id=node_instance.previous_node_id, parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id ) db.session.close() @@ -425,6 +444,8 @@ class GraphEngine: # add parallel info to iteration event item.parallel_id = parallel_id item.parallel_start_node_id = parallel_start_node_id + item.parent_parallel_id = parent_parallel_id + item.parent_parallel_start_node_id = parent_parallel_start_node_id yield item else: @@ -441,7 +462,9 @@ class GraphEngine: node_data=node_instance.node_data, route_node_state=route_node_state, parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id ) elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): @@ -471,6 +494,9 @@ class GraphEngine: run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id + if parent_parallel_id and parent_parallel_start_node_id: + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id yield NodeRunSucceededEvent( id=node_instance.id, @@ -479,7 +505,9 @@ class GraphEngine: node_data=node_instance.node_data, route_node_state=route_node_state, parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id ) break @@ -494,6 +522,8 @@ class GraphEngine: route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id ) elif isinstance(item, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( @@ -506,6 +536,8 @@ class GraphEngine: route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id ) except GenerateTaskStoppedException: # trigger node run failed event @@ -520,6 +552,8 @@ class GraphEngine: route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id ) return except Exception as e: diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 9982ab03ef..40363df0f5 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -55,8 +55,6 @@ class AnswerStreamProcessor(StreamProcessor): yield event elif isinstance(event, NodeRunSucceededEvent): yield event - if event.node_data.title == 'LLM 4': - print("LLM 4 succeeded1") if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: # update self.route_position after all stream event finished for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: