From 6aecf42b6e5d05659ba589f62dc1d6645ba85de9 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 22 Feb 2024 22:32:33 +0800
Subject: [PATCH] fix prompt transform bugs

---
 api/core/prompt/advanced_prompt_transform.py  |  26 ++-
 api/core/prompt/prompt_transform.py           |   4 +-
 api/core/prompt/simple_prompt_transform.py    |   2 +-
 .../prompt/test_advanced_prompt_transform.py  | 193 ++++++++++++++++++
 .../prompt/test_simple_prompt_transform.py    |  46 ++++-
 5 files changed, 251 insertions(+), 20 deletions(-)
 create mode 100644 api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py

diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 397f708f1f..0ed9ec352c 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -20,7 +20,7 @@ from core.prompt.prompt_transform import PromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 
 
-class AdvancePromptTransform(PromptTransform):
+class AdvancedPromptTransform(PromptTransform):
     """
     Advanced Prompt Transform for Workflow LLM Node.
     """
@@ -74,10 +74,10 @@ class AdvancePromptTransform(PromptTransform):
         prompt_template = PromptTemplateParser(template=raw_prompt)
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
-        self._set_context_variable(context, prompt_template, prompt_inputs)
+        prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
 
         role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix
-        self._set_histories_variable(
+        prompt_inputs = self._set_histories_variable(
             memory=memory,
             raw_prompt=raw_prompt,
             role_prefix=role_prefix,
@@ -104,7 +104,7 @@
     def _get_chat_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity,
                                         inputs: dict,
-                                        query: str,
+                                        query: Optional[str],
                                         files: list[FileObj],
                                         context: Optional[str],
                                         memory: Optional[TokenBufferMemory],
@@ -122,7 +122,7 @@
             prompt_template = PromptTemplateParser(template=raw_prompt)
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
-            self._set_context_variable(context, prompt_template, prompt_inputs)
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
 
             prompt = prompt_template.format(
                 prompt_inputs
             )
@@ -136,7 +136,7 @@
             prompt_messages.append(AssistantPromptMessage(content=prompt))
 
         if memory:
-            self._append_chat_histories(memory, prompt_messages, model_config)
+            prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config)
 
         if files:
             prompt_message_contents = [TextPromptMessageContent(data=query)]
             for file in files:
                 prompt_message_contents.append(file.prompt_message_content)
@@ -157,7 +157,7 @@
                 last_message.content = prompt_message_contents
             else:
-                prompt_message_contents = [TextPromptMessageContent(data=query)]
+                prompt_message_contents = [TextPromptMessageContent(data='')]  # not for query
                 for file in files:
                     prompt_message_contents.append(file.prompt_message_content)
 
                 prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
 
         return prompt_messages
 
-    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
         if '#context#' in prompt_template.variable_keys:
             if context:
                 prompt_inputs['#context#'] = context
             else:
                 prompt_inputs['#context#'] = ''
 
-    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+        return prompt_inputs
+
+    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
         if '#query#' in prompt_template.variable_keys:
             if query:
                 prompt_inputs['#query#'] = query
             else:
                 prompt_inputs['#query#'] = ''
 
+        return prompt_inputs
+
     def _set_histories_variable(self, memory: TokenBufferMemory,
                                 raw_prompt: str,
                                 role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
                                 prompt_template: PromptTemplateParser,
                                 prompt_inputs: dict,
-                                model_config: ModelConfigEntity) -> None:
+                                model_config: ModelConfigEntity) -> dict:
         if '#histories#' in prompt_template.variable_keys:
             if memory:
                 inputs = {'#histories#': '', **prompt_inputs}
@@ -205,3 +209,5 @@ class AdvancePromptTransform(PromptTransform):
                 prompt_inputs['#histories#'] = histories
             else:
                 prompt_inputs['#histories#'] = ''
+
+        return prompt_inputs
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index c0f70ae0bb..9596976b6e 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -10,12 +10,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 class PromptTransform:
     def _append_chat_histories(self, memory: TokenBufferMemory,
                                prompt_messages: list[PromptMessage],
-                               model_config: ModelConfigEntity) -> None:
+                               model_config: ModelConfigEntity) -> list[PromptMessage]:
         if memory:
             rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
             histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
             prompt_messages.extend(histories)
 
+        return prompt_messages
+
     def _calculate_rest_token(self, prompt_messages: list[PromptMessage],
                               model_config: ModelConfigEntity) -> int:
         rest_tokens = 2000
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index a51cc86e8b..2f98fbcae8 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -177,7 +177,7 @@ class SimplePromptTransform(PromptTransform):
         if prompt:
             prompt_messages.append(SystemPromptMessage(content=prompt))
 
-        self._append_chat_histories(
+        prompt_messages = self._append_chat_histories(
             memory=memory,
             prompt_messages=prompt_messages,
             model_config=model_config
diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
new file mode 100644
index 0000000000..65a160a8e5
--- /dev/null
+++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
@@ -0,0 +1,193 @@
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.entities.application_entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \
+    ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity
+from core.file.file_obj import FileObj, FileType, FileTransferMethod
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
+from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.prompt_template import PromptTemplateParser
+from models.model import Conversation
+
+
+def test__get_completion_model_prompt_messages():
+    model_config_mock = MagicMock(spec=ModelConfigEntity)
+    model_config_mock.provider = 'openai'
+    model_config_mock.model = 'gpt-3.5-turbo-instruct'
+
+    prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
+    prompt_template_entity = PromptTemplateEntity(
+        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
+        advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
+            prompt=prompt_template,
+            role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
+                user="Human",
+                assistant="Assistant"
+            )
+        )
+    )
+    inputs = {
+        "name": "John"
+    }
+    files = []
+    context = "I am superman."
+
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
+
+    prompt_transform = AdvancedPromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    prompt_messages = prompt_transform._get_completion_model_prompt_messages(
+        prompt_template_entity=prompt_template_entity,
+        inputs=inputs,
+        files=files,
+        context=context,
+        memory=memory,
+        model_config=model_config_mock
+    )
+
+    assert len(prompt_messages) == 1
+    assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({
+        "#context#": context,
+        "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: "
+                                  f"{prompt.content}" for prompt in history_prompt_messages]),
+        **inputs,
+    })
+
+
+def test__get_chat_model_prompt_messages(get_chat_model_args):
+    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+
+    files = []
+    query = "Hi2."
+
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi1."),
+        AssistantPromptMessage(content="Hello1!")
+    ]
+    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
+
+    prompt_transform = AdvancedPromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    prompt_messages = prompt_transform._get_chat_model_prompt_messages(
+        prompt_template_entity=prompt_template_entity,
+        inputs=inputs,
+        query=query,
+        files=files,
+        context=context,
+        memory=memory,
+        model_config=model_config_mock
+    )
+
+    assert len(prompt_messages) == 6
+    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
+    assert prompt_messages[0].content == PromptTemplateParser(
+        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+    ).format({**inputs, "#context#": context})
+    assert prompt_messages[5].content == query
+
+
+def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
+    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+
+    files = []
+
+    prompt_transform = AdvancedPromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    prompt_messages = prompt_transform._get_chat_model_prompt_messages(
+        prompt_template_entity=prompt_template_entity,
+        inputs=inputs,
+        query=None,
+        files=files,
+        context=context,
+        memory=None,
+        model_config=model_config_mock
+    )
+
+    assert len(prompt_messages) == 3
+    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
+    assert prompt_messages[0].content == PromptTemplateParser(
+        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+    ).format({**inputs, "#context#": context})
+
+
+def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
+    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+
+    files = [
+        FileObj(
+            id="file1",
+            tenant_id="tenant1",
+            type=FileType.IMAGE,
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            url="https://example.com/image1.jpg",
+            file_config={
+                "image": {
+                    "detail": "high",
+                }
+            }
+        )
+    ]
+
+    prompt_transform = AdvancedPromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+    prompt_messages = prompt_transform._get_chat_model_prompt_messages(
+        prompt_template_entity=prompt_template_entity,
+        inputs=inputs,
+        query=None,
+        files=files,
+        context=context,
+        memory=None,
+        model_config=model_config_mock
+    )
+
+    assert len(prompt_messages) == 4
+    assert prompt_messages[0].role == PromptMessageRole.SYSTEM
+    assert prompt_messages[0].content == PromptTemplateParser(
+        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+    ).format({**inputs, "#context#": context})
+    assert isinstance(prompt_messages[3].content, list)
+    assert len(prompt_messages[3].content) == 2
+    assert prompt_messages[3].content[1].data == files[0].url
+
+
+@pytest.fixture
+def get_chat_model_args():
+    model_config_mock = MagicMock(spec=ModelConfigEntity)
+    model_config_mock.provider = 'openai'
+    model_config_mock.model = 'gpt-4'
+
+    prompt_template_entity = PromptTemplateEntity(
+        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
+        advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
+            messages=[
+                AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
+                                          role=PromptMessageRole.SYSTEM),
+                AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
+                AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
+            ]
+        )
+    )
+
+    inputs = {
+        "name": "John"
+    }
+
+    context = "I am superman."
+
+    return model_config_mock, prompt_template_entity, inputs, context
diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
index cb6ad02541..c174983e38 100644
--- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
@@ -1,8 +1,10 @@
 from unittest.mock import MagicMock
 
 from core.entities.application_entities import ModelConfigEntity
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
 from core.prompt.simple_prompt_transform import SimplePromptTransform
-from models.model import AppMode
+from models.model import AppMode, Conversation
 
 
 def test_get_common_chat_app_prompt_template_with_pcqm():
@@ -141,7 +143,16 @@ def test__get_chat_model_prompt_messages():
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-4'
 
+    memory_mock = MagicMock(spec=TokenBufferMemory)
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
+
     prompt_transform = SimplePromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
+
     pre_prompt = "You are a helpful assistant {{name}}."
     inputs = {
         "name": "John"
@@ -154,7 +165,7 @@ def test__get_chat_model_prompt_messages():
         query=query,
         files=[],
         context=context,
-        memory=None,
+        memory=memory_mock,
         model_config=model_config_mock
     )
 
@@ -171,9 +182,11 @@ def test__get_chat_model_prompt_messages():
     full_inputs = {**inputs, '#context#': context}
     real_system_prompt = prompt_template['prompt_template'].format(full_inputs)
 
-    assert len(prompt_messages) == 2
+    assert len(prompt_messages) == 4
     assert prompt_messages[0].content == real_system_prompt
-    assert prompt_messages[1].content == query
+    assert prompt_messages[1].content == history_prompt_messages[0].content
+    assert prompt_messages[2].content == history_prompt_messages[1].content
+    assert prompt_messages[3].content == query
 
 
 def test__get_completion_model_prompt_messages():
     model_config_mock = MagicMock(spec=ModelConfigEntity)
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-3.5-turbo-instruct'
 
+    memory = TokenBufferMemory(
+        conversation=Conversation(),
+        model_instance=model_config_mock
+    )
+
+    history_prompt_messages = [
+        UserPromptMessage(content="Hi"),
+        AssistantPromptMessage(content="Hello")
+    ]
+    memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
+
     prompt_transform = SimplePromptTransform()
+    prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     pre_prompt = "You are a helpful assistant {{name}}."
     inputs = {
         "name": "John"
@@ -194,7 +219,7 @@ def test__get_completion_model_prompt_messages():
         query=query,
         files=[],
         context=context,
-        memory=None,
+        memory=memory,
         model_config=model_config_mock
     )
 
@@ -205,12 +230,17 @@ def test__get_completion_model_prompt_messages():
         pre_prompt=pre_prompt,
         has_context=True,
         query_in_prompt=True,
-        with_memory_prompt=False,
+        with_memory_prompt=True,
     )
 
-    full_inputs = {**inputs, '#context#': context, '#query#': query}
+    prompt_rules = prompt_template['prompt_rules']
+    full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
+        max_token_limit=2000,
+        human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+        ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+    )}
     real_prompt = prompt_template['prompt_template'].format(full_inputs)
 
     assert len(prompt_messages) == 1
-    assert stops == prompt_template['prompt_rules'].get('stops')
+    assert stops == prompt_rules.get('stops')
    assert prompt_messages[0].content == real_prompt
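
Reviewer note, not part of the patch itself: the change repeated across these hunks is a calling-convention fix. Helpers such as _set_context_variable, _set_query_variable, _set_histories_variable, and _append_chat_histories previously mutated their argument and returned None; they now return the updated dict or message list, and every call site rebinds the result. That makes the data flow explicit and lets the new unit tests assert on a returned value. A minimal standalone sketch of the convention follows; the names are illustrative, not the real Dify classes:

    # Hypothetical helper with illustrative names: it returns the updated dict
    # instead of relying on silent in-place mutation, mirroring the pattern
    # adopted by _set_context_variable in this patch.
    def set_context_variable(context: str, variable_keys: set, prompt_inputs: dict) -> dict:
        if '#context#' in variable_keys:
            # Fill '#context#' only when the template declares that variable.
            prompt_inputs['#context#'] = context if context else ''
        return prompt_inputs

    prompt_inputs = {'name': 'John'}
    # The caller rebinds rather than depending on the side effect alone.
    prompt_inputs = set_context_variable('I am superman.', {'name', '#context#'}, prompt_inputs)
    assert prompt_inputs == {'name': 'John', '#context#': 'I am superman.'}

The same shape applies to the list-returning helpers: callers write prompt_messages = self._append_chat_histories(...), so the code keeps working even if a future implementation returns a new object instead of extending the one passed in.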