From a44d3c3eda15ec303480601c858cec6cf01871dd Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 22 Feb 2024 15:15:42 +0800 Subject: [PATCH] fix bugs and add unit tests --- .../model_runtime/entities/model_entities.py | 2 +- .../model_providers/__base/tts_model.py | 4 +- api/core/prompt/simple_prompt_transform.py | 35 +-- api/models/workflow.py | 4 +- api/tests/unit_tests/.gitignore | 1 + api/tests/unit_tests/__init__.py | 0 api/tests/unit_tests/conftest.py | 7 + api/tests/unit_tests/core/__init__.py | 0 api/tests/unit_tests/core/prompt/__init__.py | 0 .../core/prompt/test_prompt_transform.py | 47 ++++ .../prompt/test_simple_prompt_transform.py | 216 ++++++++++++++++++ 11 files changed, 295 insertions(+), 21 deletions(-) create mode 100644 api/tests/unit_tests/.gitignore create mode 100644 api/tests/unit_tests/__init__.py create mode 100644 api/tests/unit_tests/conftest.py create mode 100644 api/tests/unit_tests/core/__init__.py create mode 100644 api/tests/unit_tests/core/prompt/__init__.py create mode 100644 api/tests/unit_tests/core/prompt/test_prompt_transform.py create mode 100644 api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 60cb655c98..7dfd811b4f 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -133,7 +133,7 @@ class ModelPropertyKey(Enum): DEFAULT_VOICE = "default_voice" VOICES = "voices" WORD_LIMIT = "word_limit" - AUDOI_TYPE = "audio_type" + AUDIO_TYPE = "audio_type" MAX_WORKERS = "max_workers" diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 722d80c91e..22e546aad7 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -94,8 +94,8 @@ class TTSModel(AIModel): """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE] + if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties: + return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 6e158bef39..a51cc86e8b 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -45,6 +45,7 @@ class SimplePromptTransform(PromptTransform): """ Simple Prompt Transform for Chatbot App Basic Mode. 
""" + def get_prompt(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, @@ -154,12 +155,12 @@ class SimplePromptTransform(PromptTransform): } def _get_chat_model_prompt_messages(self, pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) \ + inputs: dict, + query: str, + context: Optional[str], + files: list[FileObj], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] @@ -169,7 +170,7 @@ class SimplePromptTransform(PromptTransform): model_config=model_config, pre_prompt=pre_prompt, inputs=inputs, - query=query, + query=None, context=context ) @@ -187,12 +188,12 @@ class SimplePromptTransform(PromptTransform): return prompt_messages, None def _get_completion_model_prompt_messages(self, pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list[FileObj], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) \ + inputs: dict, + query: str, + context: Optional[str], + files: list[FileObj], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( @@ -259,7 +260,7 @@ class SimplePromptTransform(PromptTransform): provider=provider, model=model ) - + # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: return prompt_file_contents[prompt_file_name] @@ -267,14 +268,16 @@ class SimplePromptTransform(PromptTransform): # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts') json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') - + # Open the JSON file and read its content with open(json_file_path, encoding='utf-8') as json_file: content = json.load(json_file) - + # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content + return content + def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan is_baichuan = False diff --git a/api/models/workflow.py b/api/models/workflow.py index ed26e98896..95805e7871 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -5,7 +5,6 @@ from sqlalchemy.dialects.postgresql import UUID from extensions.ext_database import db from models.account import Account -from models.model import AppMode class WorkflowType(Enum): @@ -29,13 +28,14 @@ class WorkflowType(Enum): raise ValueError(f'invalid workflow type value {value}') @classmethod - def from_app_mode(cls, app_mode: Union[str, AppMode]) -> 'WorkflowType': + def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': """ Get workflow type from app mode. 
         :param app_mode: app mode
         :return: workflow type
         """
+        from models.model import AppMode
         app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode)
         return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
 
diff --git a/api/tests/unit_tests/.gitignore b/api/tests/unit_tests/.gitignore
new file mode 100644
index 0000000000..426667562b
--- /dev/null
+++ b/api/tests/unit_tests/.gitignore
@@ -0,0 +1 @@
+.env.test
\ No newline at end of file
diff --git a/api/tests/unit_tests/__init__.py b/api/tests/unit_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py
new file mode 100644
index 0000000000..afc9802cf1
--- /dev/null
+++ b/api/tests/unit_tests/conftest.py
@@ -0,0 +1,7 @@
+import os
+
+# Getting the absolute path of the current file's directory
+ABS_PATH = os.path.dirname(os.path.abspath(__file__))
+
+# Getting the absolute path of the project's root directory
+PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
diff --git a/api/tests/unit_tests/core/__init__.py b/api/tests/unit_tests/core/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/prompt/__init__.py b/api/tests/unit_tests/core/prompt/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
new file mode 100644
index 0000000000..8a260b0507
--- /dev/null
+++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
@@ -0,0 +1,47 @@
+from unittest.mock import MagicMock
+
+from core.entities.application_entities import ModelConfigEntity
+from core.entities.provider_configuration import ProviderModelBundle
+from core.model_runtime.entities.message_entities import UserPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey, AIModelEntity, ParameterRule
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.prompt_transform import PromptTransform
+
+
+def test__calculate_rest_token():
+    model_schema_mock = MagicMock(spec=AIModelEntity)
+    parameter_rule_mock = MagicMock(spec=ParameterRule)
+    parameter_rule_mock.name = 'max_tokens'
+    model_schema_mock.parameter_rules = [
+        parameter_rule_mock
+    ]
+    model_schema_mock.model_properties = {
+        ModelPropertyKey.CONTEXT_SIZE: 62
+    }
+
+    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
+    large_language_model_mock.get_num_tokens.return_value = 6
+
+    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
+    provider_model_bundle_mock.model_type_instance = large_language_model_mock
+
+    model_config_mock = MagicMock(spec=ModelConfigEntity)
+    model_config_mock.model = 'gpt-4'
+    model_config_mock.credentials = {}
+    model_config_mock.parameters = {
+        'max_tokens': 50
+    }
+    model_config_mock.model_schema = model_schema_mock
+    model_config_mock.provider_model_bundle = provider_model_bundle_mock
+
+    prompt_transform = PromptTransform()
+
+    prompt_messages = [UserPromptMessage(content="Hello, how are you?")]
+    rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
+
+    # Validate based on the mock configuration and expected logic
+    expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
+                            - model_config_mock.parameters['max_tokens']
+                            - large_language_model_mock.get_num_tokens.return_value)
+    assert rest_tokens == expected_rest_tokens
+    assert rest_tokens == 6
diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
new file mode 100644
index 0000000000..cb6ad02541
--- /dev/null
+++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
@@ -0,0 +1,216 @@
+from unittest.mock import MagicMock
+
+from core.entities.application_entities import ModelConfigEntity
+from core.prompt.simple_prompt_transform import SimplePromptTransform
+from models.model import AppMode
+
+
+def test_get_common_chat_app_prompt_template_with_pcqm():
+    prompt_transform = SimplePromptTransform()
+    pre_prompt = "You are a helpful assistant."
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.CHAT,
+        provider="openai",
+        model="gpt-4",
+        pre_prompt=pre_prompt,
+        has_context=True,
+        query_in_prompt=True,
+        with_memory_prompt=True,
+    )
+    prompt_rules = prompt_template['prompt_rules']
+    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+                                                           + pre_prompt + '\n'
+                                                           + prompt_rules['histories_prompt']
+                                                           + prompt_rules['query_prompt'])
+    assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
+
+
+def test_get_baichuan_chat_app_prompt_template_with_pcqm():
+    prompt_transform = SimplePromptTransform()
+    pre_prompt = "You are a helpful assistant."
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.CHAT,
+        provider="baichuan",
+        model="Baichuan2-53B",
+        pre_prompt=pre_prompt,
+        has_context=True,
+        query_in_prompt=True,
+        with_memory_prompt=True,
+    )
+    prompt_rules = prompt_template['prompt_rules']
+    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+                                                           + pre_prompt + '\n'
+                                                           + prompt_rules['histories_prompt']
+                                                           + prompt_rules['query_prompt'])
+    assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
+
+
+def test_get_common_completion_app_prompt_template_with_pcq():
+    prompt_transform = SimplePromptTransform()
+    pre_prompt = "You are a helpful assistant."
+    prompt_template = prompt_transform.get_prompt_template(
+        app_mode=AppMode.WORKFLOW,
+        provider="openai",
+        model="gpt-4",
+        pre_prompt=pre_prompt,
+        has_context=True,
+        query_in_prompt=True,
+        with_memory_prompt=False,
+    )
+    prompt_rules = prompt_template['prompt_rules']
+    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+                                                           + pre_prompt + '\n'
+                                                           + prompt_rules['query_prompt'])
+    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
+
+
+def test_get_baichuan_completion_app_prompt_template_with_pcq():
+    prompt_transform = SimplePromptTransform()
+    pre_prompt = "You are a helpful assistant."
+ prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.WORKFLOW, + provider="baichuan", + model="Baichuan2-53B", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + print(prompt_template['prompt_template'].template) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + pre_prompt + '\n' + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_common_chat_app_prompt_template_with_q(): + prompt_transform = SimplePromptTransform() + pre_prompt = "" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=False, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == prompt_rules['query_prompt'] + assert prompt_template['special_variable_keys'] == ['#query#'] + + +def test_get_common_chat_app_prompt_template_with_cq(): + prompt_transform = SimplePromptTransform() + pre_prompt = "" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + prompt_rules = prompt_template['prompt_rules'] + assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt'] + + prompt_rules['query_prompt']) + assert prompt_template['special_variable_keys'] == ['#context#', '#query#'] + + +def test_get_common_chat_app_prompt_template_with_p(): + prompt_transform = SimplePromptTransform() + pre_prompt = "you are {{name}}" + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider="openai", + model="gpt-4", + pre_prompt=pre_prompt, + has_context=False, + query_in_prompt=False, + with_memory_prompt=False, + ) + assert prompt_template['prompt_template'].template == pre_prompt + '\n' + assert prompt_template['custom_variable_keys'] == ['name'] + assert prompt_template['special_variable_keys'] == [] + + +def test__get_chat_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-4' + + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant {{name}}." + inputs = { + "name": "John" + } + context = "yes or no." + query = "How are you?" 
+ prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages( + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + files=[], + context=context, + memory=None, + model_config=model_config_mock + ) + + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider=model_config_mock.provider, + model=model_config_mock.model, + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=False, + with_memory_prompt=False, + ) + + full_inputs = {**inputs, '#context#': context} + real_system_prompt = prompt_template['prompt_template'].format(full_inputs) + + assert len(prompt_messages) == 2 + assert prompt_messages[0].content == real_system_prompt + assert prompt_messages[1].content == query + + +def test__get_completion_model_prompt_messages(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = 'gpt-3.5-turbo-instruct' + + prompt_transform = SimplePromptTransform() + pre_prompt = "You are a helpful assistant {{name}}." + inputs = { + "name": "John" + } + context = "yes or no." + query = "How are you?" + prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages( + pre_prompt=pre_prompt, + inputs=inputs, + query=query, + files=[], + context=context, + memory=None, + model_config=model_config_mock + ) + + prompt_template = prompt_transform.get_prompt_template( + app_mode=AppMode.CHAT, + provider=model_config_mock.provider, + model=model_config_mock.model, + pre_prompt=pre_prompt, + has_context=True, + query_in_prompt=True, + with_memory_prompt=False, + ) + + full_inputs = {**inputs, '#context#': context, '#query#': query} + real_prompt = prompt_template['prompt_template'].format(full_inputs) + + assert len(prompt_messages) == 1 + assert stops == prompt_template['prompt_rules'].get('stops') + assert prompt_messages[0].content == real_prompt
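A few of the fixes above are easy to miss in the hunk noise. First, _get_prompt_file_content stored the parsed JSON in the module-level prompt_file_contents cache but fell off the end of the function on a cache miss, so the first call for any template returned None; the patch adds the missing "return content". A minimal sketch of the corrected memoized-loader shape (standalone and illustrative; load_prompt_file and _prompt_file_cache are stand-in names, not the real dify API):

import json
import os

# Module-level memo, playing the role of prompt_file_contents in the patch.
_prompt_file_cache: dict[str, dict] = {}


def load_prompt_file(prompt_file_name: str) -> dict:
    # Cache hit: reuse the already-parsed template.
    if prompt_file_name in _prompt_file_cache:
        return _prompt_file_cache[prompt_file_name]

    # Cache miss: read and parse the JSON template exactly once.
    prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'generate_prompts')
    json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')
    with open(json_file_path, encoding='utf-8') as json_file:
        content = json.load(json_file)

    _prompt_file_cache[prompt_file_name] = content
    return content  # the line the patch adds; without it the first call yielded None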
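Second, the new test__calculate_rest_token pins down the token-budget arithmetic behind PromptTransform._calculate_rest_token: the remaining budget is the model's context size, minus the max_tokens reserved for the completion, minus the tokens the prompt already consumes. A self-contained restatement using the same numbers as the test fixture (the helper name is illustrative; the arithmetic is exactly what the test asserts):

def calculate_rest_tokens(context_size: int, max_completion_tokens: int, prompt_tokens: int) -> int:
    """Tokens still available after reserving the completion and counting the prompt."""
    return context_size - max_completion_tokens - prompt_tokens


# Mirrors the mock setup in test__calculate_rest_token: 62 - 50 - 6 == 6
assert calculate_rest_tokens(context_size=62, max_completion_tokens=50, prompt_tokens=6) == 6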
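Finally, the new suites are plain pytest. Assuming pytest is installed in the api environment, they can be run from the api directory with, for example:

    pytest tests/unit_tests/core/prompt -v

The new tests/unit_tests/.gitignore entry keeps a local .env.test file out of version control alongside them.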