Ensure suggested questions parser returns typed sequence (#27104)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN- 2025-10-20 13:01:09 +08:00 committed by GitHub
parent 4c37d650d3
commit 4dccdf9478
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 7 deletions

View File

@ -100,7 +100,7 @@ class LLMGenerator:
return name
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
@ -119,6 +119,8 @@ class LLMGenerator:
prompt_messages = [UserPromptMessage(content=prompt)]
questions: Sequence[str] = []
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),

View File

@ -1,17 +1,26 @@
import json
import logging
import re
from collections.abc import Sequence
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
def parse(self, text: str):
def parse(self, text: str) -> Sequence[str]:
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
questions: list[str] = []
if action_match is not None:
json_obj = json.loads(action_match.group(0).strip())
else:
json_obj = []
return json_obj
try:
json_obj = json.loads(action_match.group(0).strip())
except json.JSONDecodeError as exc:
logger.warning("Failed to decode suggested questions payload: %s", exc)
else:
if isinstance(json_obj, list):
questions = [question for question in json_obj if isinstance(question, str)]
return questions

View File

@ -288,9 +288,10 @@ class MessageService:
)
with measure_time() as timer:
questions: list[str] = LLMGenerator.generate_suggested_questions_after_answer(
questions_sequence = LLMGenerator.generate_suggested_questions_after_answer(
tenant_id=app_model.tenant_id, histories=histories
)
questions: list[str] = list(questions_sequence)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)