diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 01b2908f85..f5783f37c0 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -129,7 +129,7 @@ class QuestionClassifierNode(LLMNode):
         :return:
         """
         prompt_transform = AdvancedPromptTransform()
-        prompt_template = self._get_prompt_template(node_data, query)
+        prompt_template = self._get_prompt_template(node_data, query, memory)
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs={},
@@ -137,14 +137,15 @@ class QuestionClassifierNode(LLMNode):
             files=[],
             context=context,
             memory_config=node_data.memory,
-            memory=memory,
+            memory=None,
             model_config=model_config
         )
         stop = model_config.stop
 
         return prompt_messages, stop
 
-    def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str) \
+    def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str,
+                             memory: Optional[TokenBufferMemory]) \
             -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
@@ -152,12 +153,15 @@ class QuestionClassifierNode(LLMNode):
         class_names_str = ','.join(class_names)
         instruction = node_data.instruction if node_data.instruction else ''
         input_text = query
-
+        memory_str = ''
+        if memory:
+            memory_str = memory.get_history_prompt_text(max_token_limit=2000,
+                                                        message_limit=node_data.memory.window.size)
         prompt_messages = []
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
-                text=QUESTION_CLASSIFIER_SYSTEM_PROMPT
+                text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
             )
             prompt_messages.append(system_prompt_messages)
             user_prompt_message_1 = ChatModelMessage(
@@ -182,14 +186,17 @@ class QuestionClassifierNode(LLMNode):
             prompt_messages.append(assistant_prompt_message_2)
             user_prompt_message_3 = ChatModelMessage(
                 role=PromptMessageRole.USER,
-                text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text, categories=class_names_str,
+                text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
+                                                              categories=class_names_str,
                                                               classification_instructions=instruction)
             )
             prompt_messages.append(user_prompt_message_3)
 
             return prompt_messages
         elif model_mode == ModelMode.COMPLETION:
             return CompletionModelPromptTemplate(
-                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(input_text=input_text, categories=class_names_str,
+                text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
+                                                                  input_text=input_text,
+                                                                  categories=class_names_str,
                                                                   classification_instructions=instruction)
             )
diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py
index faf44269ac..829f0257bc 100644
--- a/api/core/workflow/nodes/question_classifier/template_prompts.py
+++ b/api/core/workflow/nodes/question_classifier/template_prompts.py
@@ -9,6 +9,11 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
     The input text is in the variable text_field.Categories are specified as a comma-separated list in the variable categories or left empty for automatic determination.Classification instructions may be included to improve the classification accuracy.
     ### Constraint
     DO NOT include anything other than the JSON array in your response.
+    ### Memory
+    Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
+    <histories>
+    {histories}
+    </histories>
     """
 
 QUESTION_CLASSIFIER_USER_PROMPT_1 = """
@@ -58,6 +63,9 @@ Assistant:{{"keywords": ["recently", "great experience", "company", "service", "
 ### Memory
 Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
+<histories>
+{histories}
+</histories>
 ### User Input
 {{"input_text" : ["{input_text}"], "categories" : ["{categories}"],"classification_instruction" : ["{classification_instructions}"]}}
 ### Assistant Output
 """
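
Net effect of the change: conversation history is now serialized to plain text and substituted into the new `{histories}` placeholder of the templates themselves, instead of being appended as separate chat messages by `AdvancedPromptTransform` (which is why `get_prompt` now receives `memory=None`). Below is a minimal runnable sketch of that flow; `FakeTokenBufferMemory` and `render_system_prompt` are hypothetical stand-ins, since the real `TokenBufferMemory` is constructed from a conversation and model instance, and the template here is trimmed to the section this diff adds.

```python
from typing import Optional

# Trimmed copy of QUESTION_CLASSIFIER_SYSTEM_PROMPT; only the Memory section
# added by this diff matters for the sketch.
SYSTEM_PROMPT = """
    ### Constraint
    DO NOT include anything other than the JSON array in your response.
    ### Memory
    Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
    <histories>
    {histories}
    </histories>
    """


class FakeTokenBufferMemory:
    """Hypothetical stand-in for core.memory.token_buffer_memory.TokenBufferMemory."""

    def __init__(self, turns: list[tuple[str, str]]):
        self.turns = turns

    def get_history_prompt_text(self, max_token_limit: int = 2000,
                                message_limit: Optional[int] = None) -> str:
        # Render the most recent turns as plain text, mimicking the real API.
        turns = self.turns[-message_limit:] if message_limit else self.turns
        return "\n".join(f"Human: {q}\nAssistant: {a}" for q, a in turns)


def render_system_prompt(memory: Optional[FakeTokenBufferMemory],
                         window_size: Optional[int] = 10) -> str:
    # Mirrors _get_prompt_template: render memory to text up front, then
    # inject it into the template. The prompt transform never sees the
    # memory object anymore, so history is not duplicated as messages.
    memory_str = ''
    if memory:
        memory_str = memory.get_history_prompt_text(max_token_limit=2000,
                                                    message_limit=window_size)
    return SYSTEM_PROMPT.format(histories=memory_str)


memory = FakeTokenBufferMemory([("What plans do you offer?", "Free, Pro and Team.")])
print(render_system_prompt(memory))  # history appears inside <histories>...</histories>
print(render_system_prompt(None))    # no memory: the block renders empty but still formats
```

A plausible reason for inlining history as text rather than passing the memory object through: the classifier prompt already contains scripted user/assistant few-shot turns, and interleaving genuine history messages among them would break that few-shot structure.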