add model support for kr node single_retrieval_config

This commit is contained in:
takatost 2024-03-16 22:22:08 +08:00
parent 65ed4dc91f
commit 36180b1001
2 changed files with 39 additions and 4 deletions

View File

@ -131,7 +131,8 @@ class WorkflowConverter:
if app_config.dataset:
knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode,
dataset_config=app_config.dataset
dataset_config=app_config.dataset,
model_config=app_config.model
)
if knowledge_retrieval_node:
@ -359,12 +360,15 @@ class WorkflowConverter:
return nodes
def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, dataset_config: DatasetEntity) \
def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode,
dataset_config: DatasetEntity,
model_config: ModelConfigEntity) \
-> Optional[dict]:
"""
Convert datasets to Knowledge Retrieval Node
:param new_app_mode: new app mode
:param dataset_config: dataset
:param model_config: model config
:return:
"""
retrieve_config = dataset_config.retrieve_config
@ -385,6 +389,19 @@ class WorkflowConverter:
"query_variable_selector": query_variable_selector,
"dataset_ids": dataset_config.dataset_ids,
"retrieval_mode": retrieve_config.retrieve_strategy.value,
"single_retrieval_config": {
"model": {
"provider": model_config.provider,
"name": model_config.model,
"mode": model_config.mode,
"completion_params": {
**model_config.parameters,
"stop": model_config.stop,
}
}
}
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE
else None,
"multiple_retrieval_config": {
"top_k": retrieve_config.top_k,
"score_threshold": retrieve_config.score_threshold,

View File

@ -206,9 +206,18 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot():
)
)
model_config = ModelConfigEntity(
provider='openai',
model='gpt-4',
mode='chat',
parameters={},
stop=[]
)
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode,
dataset_config=dataset_config
dataset_config=dataset_config,
model_config=model_config
)
assert node["data"]["type"] == "knowledge-retrieval"
@ -240,9 +249,18 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app():
)
)
model_config = ModelConfigEntity(
provider='openai',
model='gpt-4',
mode='chat',
parameters={},
stop=[]
)
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode,
dataset_config=dataset_config
dataset_config=dataset_config,
model_config=model_config
)
assert node["data"]["type"] == "knowledge-retrieval"