dataset retrieval

This commit is contained in:
jyong 2024-03-15 16:14:32 +08:00
parent c1b0f115d0
commit 3e4bb695e4
1 changed file with 13 additions and 11 deletions

View File

@ -33,10 +33,10 @@ default_retrieval_model = {
'score_threshold_enabled': False
}
class KnowledgeRetrievalNode(BaseNode):
class KnowledgeRetrievalNode(BaseNode):
_node_data_cls = KnowledgeRetrievalNodeData
_node_type = NodeType.TOOL
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
@ -67,7 +67,9 @@ class KnowledgeRetrievalNode(BaseNode):
inputs=variables,
error=str(e)
)
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[dict[str, Any]]:
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]) -> list[
dict[str, Any]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param node_data: node data
@ -224,14 +226,14 @@ class KnowledgeRetrievalNode(BaseNode):
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=variables['#query#'],
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id,
query=variables['#query#'],
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model)
return results
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[
ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param node_data: node data
@ -333,7 +335,7 @@ class KnowledgeRetrievalNode(BaseNode):
return all_documents
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
@ -368,4 +370,4 @@ class KnowledgeRetrievalNode(BaseNode):
if retrieval_model['reranking_enable'] else None
)
all_documents.extend(documents)
all_documents.extend(documents)