From f11bf9153deee59d773d30d073e272d22f0082bc Mon Sep 17 00:00:00 2001 From: takatost Date: Sun, 25 Feb 2024 13:47:43 +0800 Subject: [PATCH] add more tests --- .../workflow/test_workflow_converter.py | 266 +++++++++++++++++- 1 file changed, 263 insertions(+), 3 deletions(-) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 69cf6afe45..ee9e5eb2fa 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -4,8 +4,12 @@ from unittest.mock import MagicMock import pytest -from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity +from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity, DatasetEntity, \ + DatasetRetrieveConfigEntity, ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ + AdvancedChatMessageEntity, AdvancedCompletionPromptTemplateEntity from core.helper import encrypter +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.message_entities import PromptMessageRole from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import AppMode from services.workflow.workflow_converter import WorkflowConverter @@ -42,9 +46,9 @@ def test__convert_to_start_node(default_variables): assert result["data"]["variables"][2]["variable"] == "select" -def test__convert_to_http_request_node(default_variables): +def test__convert_to_http_request_node_for_chatbot(default_variables): """ - Test convert to http request nodes + Test convert to http request nodes for chatbot :return: """ app_model = MagicMock() @@ -182,3 +186,259 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables): code_node = nodes[1] assert code_node["data"]["type"] == "code" + + +def test__convert_to_knowledge_retrieval_node_for_chatbot(): + new_app_mode = AppMode.CHAT + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={ + 'reranking_provider_name': 'cohere', + 'reranking_model_name': 'rerank-english-v2.0' + } + ) + ) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + dataset_config=dataset_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", "sys.query"] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert (node["data"]["retrieval_mode"] + == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model + } + + +def test__convert_to_knowledge_retrieval_node_for_workflow_app(): + new_app_mode = AppMode.WORKFLOW + + dataset_config = DatasetEntity( + dataset_ids=["dataset_id_1", "dataset_id_2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable="query", + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={ + 'reranking_provider_name': 'cohere', + 'reranking_model_name': 'rerank-english-v2.0' + } + ) + ) + + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=new_app_mode, + dataset_config=dataset_config + ) + + assert node["data"]["type"] == "knowledge-retrieval" + assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] + assert node["data"]["dataset_ids"] == dataset_config.dataset_ids + assert (node["data"]["retrieval_mode"] + == dataset_config.retrieve_config.retrieve_strategy.value) + assert node["data"]["multiple_retrieval_config"] == { + "top_k": dataset_config.retrieve_config.top_k, + "score_threshold": dataset_config.retrieve_config.score_threshold, + "reranking_model": dataset_config.retrieve_config.reranking_model + } + + +def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert llm_node["data"]["prompts"][0]['text'] == prompt_template.simple_prompt_template + '\n' + assert llm_node["data"]['context']['enabled'] is False + + +def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are a helpful assistant {{text-input}}, {{paragraph}}, {{select}}." + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert llm_node["data"]["prompts"]['text'] == prompt_template.simple_prompt_template + '\n' + assert llm_node["data"]['context']['enabled'] is False + + +def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-4" + model_mode = LLMMode.CHAT + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[ + AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ]) + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert isinstance(llm_node["data"]["prompts"], list) + assert len(llm_node["data"]["prompts"]) == len(prompt_template.advanced_chat_prompt_template.messages) + assert llm_node["data"]["prompts"][0]['text'] == prompt_template.advanced_chat_prompt_template.messages[0].text + + +def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): + new_app_mode = AppMode.CHAT + model = "gpt-3.5-turbo-instruct" + model_mode = LLMMode.COMPLETION + + workflow_converter = WorkflowConverter() + start_node = workflow_converter._convert_to_start_node(default_variables) + graph = { + "nodes": [ + start_node + ], + "edges": [] # no need + } + + model_config_mock = MagicMock(spec=ModelConfigEntity) + model_config_mock.provider = 'openai' + model_config_mock.model = model + model_config_mock.mode = model_mode.value + model_config_mock.parameters = {} + model_config_mock.stop = [] + + prompt_template = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" + "Human: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", + assistant="Assistant" + ) + ) + ) + + llm_node = workflow_converter._convert_to_llm_node( + new_app_mode=new_app_mode, + model_config=model_config_mock, + graph=graph, + prompt_template=prompt_template + ) + + assert llm_node["data"]["type"] == "llm" + assert llm_node["data"]["model"]['name'] == model + assert llm_node["data"]['model']["mode"] == model_mode.value + assert llm_node["data"]["variables"] == [{ + "variable": v.variable, + "value_selector": ["start", v.variable] + } for v in default_variables] + assert isinstance(llm_node["data"]["prompts"], dict) + assert llm_node["data"]["prompts"]['text'] == prompt_template.advanced_completion_prompt_template.prompt