From 36180b1001c5480987255b43cd39464bc461e78d Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 16 Mar 2024 22:22:08 +0800 Subject: [PATCH] add model support for kr node single_retrieval_config --- api/services/workflow/workflow_converter.py | 21 ++++++++++++++++-- .../workflow/test_workflow_converter.py | 22 +++++++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 953c5c5a3c..b1b0b2f315 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -131,7 +131,8 @@ class WorkflowConverter: if app_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( new_app_mode=new_app_mode, - dataset_config=app_config.dataset + dataset_config=app_config.dataset, + model_config=app_config.model ) if knowledge_retrieval_node: @@ -359,12 +360,15 @@ class WorkflowConverter: return nodes - def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset_config: DatasetEntity) \ + def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, + dataset_config: DatasetEntity, + model_config: ModelConfigEntity) \ -> Optional[dict]: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode :param dataset_config: dataset + :param model_config: model config :return: """ retrieve_config = dataset_config.retrieve_config @@ -385,6 +389,19 @@ class WorkflowConverter: "query_variable_selector": query_variable_selector, "dataset_ids": dataset_config.dataset_ids, "retrieval_mode": retrieve_config.retrieve_strategy.value, + "single_retrieval_config": { + "model": { + "provider": model_config.provider, + "name": model_config.model, + "mode": model_config.mode, + "completion_params": { + **model_config.parameters, + "stop": model_config.stop, + } + } + } + if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE + else None, "multiple_retrieval_config": { "top_k": retrieve_config.top_k, "score_threshold": retrieve_config.score_threshold, diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 0ca8ae135c..b4a4d6707a 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -206,9 +206,18 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot(): ) ) + model_config = ModelConfigEntity( + provider='openai', + model='gpt-4', + mode='chat', + parameters={}, + stop=[] + ) + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( new_app_mode=new_app_mode, - dataset_config=dataset_config + dataset_config=dataset_config, + model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval" @@ -240,9 +249,18 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app(): ) ) + model_config = ModelConfigEntity( + provider='openai', + model='gpt-4', + mode='chat', + parameters={}, + stop=[] + ) + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( new_app_mode=new_app_mode, - dataset_config=dataset_config + dataset_config=dataset_config, + model_config=model_config ) assert node["data"]["type"] == "knowledge-retrieval"