From 4097f7c0691d0d07485cbe15f47fbd9119ea94e2 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 25 Jul 2024 19:39:06 +0800 Subject: [PATCH] add parallel branch output --- .../workflow/graph_engine/entities/event.py | 5 ++- .../workflow/graph_engine/graph_engine.py | 41 +++++++++++++------ .../nodes/answer/answer_stream_processor.py | 2 + .../answer/test_answer_stream_processor.py | 13 +++++- 4 files changed, 45 insertions(+), 16 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 071ad164f8..3e669ea49c 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -40,7 +40,10 @@ class GraphRunFailedEvent(BaseGraphEvent): class BaseNodeEvent(GraphEngineEvent): route_node_state: RouteNodeState = Field(..., description="route node state") - parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" # iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index dcbedbeff8..a05a52355c 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -115,6 +115,10 @@ class GraphEngine: raise e def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: + parallel_start_node_id = None + if in_parallel_id: + parallel_start_node_id = start_node_id + next_node_id = start_node_id previous_route_node_state: Optional[RouteNodeState] = None while True: @@ -139,7 +143,8 @@ class GraphEngine: yield from self._run_node( route_node_state=route_node_state, previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None, - parallel_id=in_parallel_id + parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id ) self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state @@ -155,7 +160,8 @@ class GraphEngine: route_node_state.failed_reason = str(e) yield NodeRunFailedEvent( route_node_state=route_node_state, - parallel_id=in_parallel_id + parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id ) raise e @@ -287,14 +293,16 @@ class GraphEngine: def _run_node(self, route_node_state: RouteNodeState, previous_node_id: Optional[str] = None, - parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: """ Run node """ # trigger node run start event yield NodeRunStartedEvent( route_node_state=route_node_state, - parallel_id=parallel_id + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id ) # get node config @@ -305,7 +313,8 @@ class GraphEngine: route_node_state.failed_reason = f'Node {node_id} config not found.' yield NodeRunFailedEvent( route_node_state=route_node_state, - parallel_id=parallel_id + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id ) return @@ -317,7 +326,8 @@ class GraphEngine: route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.' yield NodeRunFailedEvent( route_node_state=route_node_state, - parallel_id=parallel_id + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id ) return @@ -344,8 +354,9 @@ class GraphEngine: if run_result.status == WorkflowNodeExecutionStatus.FAILED: yield NodeRunFailedEvent( + route_node_state=route_node_state, parallel_id=parallel_id, - route_node_state=route_node_state + parallel_start_node_id=parallel_start_node_id ) elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): @@ -365,24 +376,27 @@ class GraphEngine: ) yield NodeRunSucceededEvent( + route_node_state=route_node_state, parallel_id=parallel_id, - route_node_state=route_node_state + parallel_start_node_id=parallel_start_node_id ) break elif isinstance(item, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( - route_node_state=route_node_state, - parallel_id=parallel_id, chunk_content=item.chunk_content, from_variable_selector=item.from_variable_selector, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, ) elif isinstance(item, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( + retriever_resources=item.retriever_resources, + context=item.context, route_node_state=route_node_state, parallel_id=parallel_id, - retriever_resources=item.retriever_resources, - context=item.context + parallel_start_node_id=parallel_start_node_id, ) except GenerateTaskStoppedException: # trigger node run failed event @@ -390,7 +404,8 @@ class GraphEngine: route_node_state.failed_reason = "Workflow stopped." yield NodeRunFailedEvent( route_node_state=route_node_state, - parallel_id=parallel_id + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, ) return except Exception as e: diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 9bec7e9fc2..dafc01cee9 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -140,6 +140,7 @@ class AnswerStreamProcessor: chunk_content=route_chunk.text, route_node_state=event.route_node_state, parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ) else: route_chunk = cast(VarGenerateRouteChunk, route_chunk) @@ -162,6 +163,7 @@ class AnswerStreamProcessor: from_variable_selector=value_selector, route_node_state=event.route_node_state, parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, ) self.route_position[answer_node_id] += 1 diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index fe1ebffa4d..6d242a078f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -31,9 +31,16 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve start_at=datetime.now(timezone.utc).replace(tzinfo=None) ) + parallel_id = graph.node_parallel_mapping.get(next_node_id) + parallel_start_node_id = None + if parallel_id: + parallel = graph.parallel_mapping.get(parallel_id) + parallel_start_node_id = parallel.start_from_node_id if parallel else None + yield NodeRunStartedEvent( route_node_state=route_node_state, parallel_id=graph.node_parallel_mapping.get(next_node_id), + parallel_start_node_id=parallel_start_node_id ) if 'llm' in next_node_id: @@ -43,14 +50,16 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve chunk_content=str(i), route_node_state=route_node_state, from_variable_selector=[next_node_id, "text"], - parallel_id=graph.node_parallel_mapping.get(next_node_id), + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id ) route_node_state.status = RouteNodeState.Status.SUCCESS route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) yield NodeRunSucceededEvent( route_node_state=route_node_state, - parallel_id=graph.node_parallel_mapping.get(next_node_id), + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id )