From 958da42f748829e31ebfe4c15ee002e06654d5d7 Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 18 Mar 2024 14:28:07 +0800
Subject: [PATCH] fix advanced chat answer

---
 .../advanced_chat/generate_task_pipeline.py  | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 9c78373d17..a64913d770 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -230,6 +230,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
                     self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
 
+                    # generate stream outputs when node started
+                    yield from self._generate_stream_outputs_when_node_started()
+
                 yield self._workflow_node_start_to_stream_response(
                     task_id=self._application_generate_entity.task_id,
                     workflow_node_execution=workflow_node_execution
@@ -423,6 +426,37 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         return start_node_id
 
+    def _generate_stream_outputs_when_node_started(self) -> Generator:
+        """
+        Generate stream outputs.
+        :return:
+        """
+        if self._task_state.current_stream_generate_state:
+            route_chunks = self._task_state.current_stream_generate_state.generate_route[
+                           self._task_state.current_stream_generate_state.current_route_position:]
+
+            for route_chunk in route_chunks:
+                if route_chunk.type == 'text':
+                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                    for token in route_chunk.text:
+                        # handle output moderation chunk
+                        should_direct_answer = self._handle_output_moderation_chunk(token)
+                        if should_direct_answer:
+                            continue
+
+                        self._task_state.answer += token
+                        yield self._message_to_stream_response(token, self._message.id)
+                        time.sleep(0.01)
+                else:
+                    break
+
+                self._task_state.current_stream_generate_state.current_route_position += 1
+
+            # all route chunks are generated
+            if self._task_state.current_stream_generate_state.current_route_position == len(
+                    self._task_state.current_stream_generate_state.generate_route):
+                self._task_state.current_stream_generate_state = None
+
     def _generate_stream_outputs_when_node_finished(self) -> None:
         """
         Generate stream outputs.