diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index a760934020..66bf62771f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -341,19 +341,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # get generate route for stream output answer_node_id = node_config['id'] generate_route = AnswerNode.extract_generate_route_selectors(node_config) - start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id) - if not start_node_id: + start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id) + if not start_node_ids: continue - stream_generate_routes[start_node_id] = StreamGenerateRoute( - answer_node_id=answer_node_id, - generate_route=generate_route - ) + for start_node_id in start_node_ids: + stream_generate_routes[start_node_id] = StreamGenerateRoute( + answer_node_id=answer_node_id, + generate_route=generate_route + ) return stream_generate_routes - def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \ - -> Optional[str]: + def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ + -> list[str]: """ Get answer start at node id. :param graph: graph @@ -364,33 +365,38 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc edges = graph.get('edges') # fetch all ingoing edges from source node - ingoing_edge = None + ingoing_edges = [] for edge in edges: if edge.get('target') == target_node_id: - ingoing_edge = edge - break + ingoing_edges.append(edge) - if not ingoing_edge: - return None + if not ingoing_edges: + return [] - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - return None + start_node_ids = [] + for ingoing_edge in ingoing_edges: + source_node_id = ingoing_edge.get('source') + source_node = next((node for node in nodes if node.get('id') == source_node_id), None) + if not source_node: + continue - node_type = source_node.get('data', {}).get('type') - if node_type in [ - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value - ]: - start_node_id = target_node_id - elif node_type == NodeType.START.value: - start_node_id = source_node_id - else: - start_node_id = self._get_answer_start_at_node_id(graph, source_node_id) + node_type = source_node.get('data', {}).get('type') + if node_type in [ + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER.value + ]: + start_node_id = target_node_id + start_node_ids.append(start_node_id) + elif node_type == NodeType.START.value: + start_node_id = source_node_id + start_node_ids.append(start_node_id) + else: + sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id) + if sub_start_node_ids: + start_node_ids.extend(sub_start_node_ids) - return start_node_id + return start_node_ids def _generate_stream_outputs_when_node_started(self) -> Generator: """