mirror of https://github.com/langgenius/dify.git
dataset retrival
This commit is contained in:
parent
c1b0f115d0
commit
3e4bb695e4
|
|
@ -33,10 +33,10 @@ default_retrieval_model = {
|
|||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
_node_data_cls = KnowledgeRetrievalNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
|
||||
|
|
@ -67,7 +67,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
inputs=variables,
|
||||
error=str(e)
|
||||
)
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[
|
||||
dict[str, Any]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
:param node_data: node data
|
||||
|
|
@ -224,14 +226,14 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=variables['#query#'],
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
|
||||
query=variables['#query#'],
|
||||
top_k=top_k, score_threshold=score_threshold,
|
||||
reranking_model=reranking_model)
|
||||
return results
|
||||
|
||||
|
||||
|
||||
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
|
||||
ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config
|
||||
:param node_data: node data
|
||||
|
|
@ -333,7 +335,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
|
||||
return all_documents
|
||||
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == self.tenant_id,
|
||||
|
|
@ -368,4 +370,4 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||
if retrieval_model['reranking_enable'] else None
|
||||
)
|
||||
|
||||
all_documents.extend(documents)
|
||||
all_documents.extend(documents)
|
||||
|
|
|
|||
Loading…
Reference in New Issue