mirror of
https://github.com/langgenius/dify.git
synced 2026-04-22 16:06:55 +08:00
dataset retrieval
This commit is contained in:
parent
c1b0f115d0
commit
3e4bb695e4
@ -33,10 +33,10 @@ default_retrieval_model = {
|
|||||||
'score_threshold_enabled': False
|
'score_threshold_enabled': False
|
||||||
}
|
}
|
||||||
|
|
||||||
class KnowledgeRetrievalNode(BaseNode):
|
|
||||||
|
|
||||||
|
class KnowledgeRetrievalNode(BaseNode):
|
||||||
_node_data_cls = KnowledgeRetrievalNodeData
|
_node_data_cls = KnowledgeRetrievalNodeData
|
||||||
_node_type = NodeType.TOOL
|
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||||
|
|
||||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||||
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
|
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
|
||||||
@ -67,7 +67,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
inputs=variables,
|
inputs=variables,
|
||||||
error=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]:
|
|
||||||
|
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[
|
||||||
|
dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||||
:param node_data: node data
|
:param node_data: node data
|
||||||
@ -224,14 +226,14 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
if score_threshold_enabled:
|
if score_threshold_enabled:
|
||||||
score_threshold = retrieval_model_config.get("score_threshold")
|
score_threshold = retrieval_model_config.get("score_threshold")
|
||||||
|
|
||||||
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=variables['#query#'],
|
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
|
||||||
top_k=top_k, score_threshold=score_threshold,
|
query=variables['#query#'],
|
||||||
reranking_model=reranking_model)
|
top_k=top_k, score_threshold=score_threshold,
|
||||||
|
reranking_model=reranking_model)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
|
||||||
|
ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||||
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
|
||||||
"""
|
"""
|
||||||
Fetch model config
|
Fetch model config
|
||||||
:param node_data: node data
|
:param node_data: node data
|
||||||
@ -333,7 +335,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
|
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(
|
||||||
Dataset.tenant_id == self.tenant_id,
|
Dataset.tenant_id == self.tenant_id,
|
||||||
@ -368,4 +370,4 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
if retrieval_model['reranking_enable'] else None
|
if retrieval_model['reranking_enable'] else None
|
||||||
)
|
)
|
||||||
|
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user