fix variable assigner multi route

This commit is contained in:
takatost 2024-03-20 22:49:24 +08:00
parent a7e2f9caf0
commit 0d0da9a892
1 changed files with 35 additions and 29 deletions

View File

@ -341,19 +341,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id)
if not start_node_id:
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
if not start_node_ids:
continue
stream_generate_routes[start_node_id] = StreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = StreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \
-> Optional[str]:
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
@ -364,33 +365,38 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edge = None
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edge = edge
break
ingoing_edges.append(edge)
if not ingoing_edge:
return None
if not ingoing_edges:
return []
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
return None
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
elif node_type == NodeType.START.value:
start_node_id = source_node_id
else:
start_node_id = self._get_answer_start_at_node_id(graph, source_node_id)
node_type = source_node.get('data', {}).get('type')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value:
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_id
return start_node_ids
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""