mirror of https://github.com/langgenius/dify.git
knowledge fix
This commit is contained in:
parent
3e810bc490
commit
cd3c2f6b00
|
|
@ -10,7 +10,7 @@ class RerankingModelConfig(BaseModel):
|
|||
Reranking Model Config.
|
||||
"""
|
||||
provider: str
|
||||
mode: str
|
||||
model: str
|
||||
|
||||
|
||||
class MultipleRetrievalConfig(BaseModel):
|
||||
|
|
@ -48,4 +48,4 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
|||
dataset_ids: list[str]
|
||||
retrieval_mode: Literal['single', 'multiple']
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig]
|
||||
singleRetrievalConfig: Optional[SingleRetrievalConfig]
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig]
|
||||
|
|
|
|||
|
|
@ -49,9 +49,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
}
|
||||
# retrieve knowledge
|
||||
try:
|
||||
outputs = self._fetch_dataset_retriever(
|
||||
results = self._fetch_dataset_retriever(
|
||||
node_data=node_data, query=query
|
||||
)
|
||||
outputs = {
|
||||
'result': results
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
|
|
@ -95,9 +98,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
|
||||
available_datasets.append(dataset)
|
||||
all_documents = []
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
||||
all_documents = self._single_retrieve(available_datasets, node_data, query)
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
|
||||
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
|
||||
all_documents = self._multiple_retrieve(available_datasets, node_data, query)
|
||||
|
||||
document_score_list = {}
|
||||
|
|
@ -262,8 +265,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
model_name = node_data.singleRetrievalConfig.model.name
|
||||
provider_name = node_data.singleRetrievalConfig.model.provider
|
||||
model_name = node_data.single_retrieval_config.model.name
|
||||
provider_name = node_data.single_retrieval_config.model.provider
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
|
|
@ -296,14 +299,14 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = node_data.singleRetrievalConfig.model.completion_params
|
||||
completion_params = node_data.single_retrieval_config.model.completion_params
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
|
||||
# get model mode
|
||||
model_mode = node_data.singleRetrievalConfig.model.mode
|
||||
model_mode = node_data.single_retrieval_config.model.mode
|
||||
if not model_mode:
|
||||
raise ValueError("LLM mode is required.")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue