diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index e64ac25ab1..bd893b17f1 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -100,7 +100,7 @@ class LLMGenerator:
         return name
 
     @classmethod
-    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
+    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
         output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()
 
@@ -119,6 +119,8 @@ class LLMGenerator:
 
         prompt_messages = [UserPromptMessage(content=prompt)]
 
+        questions: Sequence[str] = []
+
         try:
             response: LLMResult = model_instance.invoke_llm(
                 prompt_messages=list(prompt_messages),
diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
index e78859cc1a..eec771181f 100644
--- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
+++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
@@ -1,17 +1,26 @@
 import json
+import logging
 import re
+from collections.abc import Sequence
 
 from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
 
+logger = logging.getLogger(__name__)
+
 
 class SuggestedQuestionsAfterAnswerOutputParser:
     def get_format_instructions(self) -> str:
         return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
 
-    def parse(self, text: str):
+    def parse(self, text: str) -> Sequence[str]:
         action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
+        questions: list[str] = []
         if action_match is not None:
-            json_obj = json.loads(action_match.group(0).strip())
-        else:
-            json_obj = []
-        return json_obj
+            try:
+                json_obj = json.loads(action_match.group(0).strip())
+            except json.JSONDecodeError as exc:
+                logger.warning("Failed to decode suggested questions payload: %s", exc)
+            else:
+                if isinstance(json_obj, list):
+                    questions = [question for question in json_obj if isinstance(question, str)]
+        return questions
diff --git a/api/services/message_service.py b/api/services/message_service.py
index 9fdff18622..7ed56d80f2 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -288,9 +288,10 @@ class MessageService:
         )
 
         with measure_time() as timer:
-            questions: list[str] = LLMGenerator.generate_suggested_questions_after_answer(
+            questions_sequence = LLMGenerator.generate_suggested_questions_after_answer(
                 tenant_id=app_model.tenant_id, histories=histories
             )
+            questions: list[str] = list(questions_sequence)
 
             # get tracing instance
             trace_manager = TraceQueueManager(app_id=app_model.id)