From f19219ab8df8993bdfd4aa1c2e0d402ef2819934 Mon Sep 17 00:00:00 2001 From: jyong Date: Tue, 2 Apr 2024 14:02:49 +0800 Subject: [PATCH] fix knowledge retrival --- .../knowledge_retrieval_node.py | 51 +++++++++++++++++-- ..._agent.py => multi_dataset_react_route.py} | 0 2 files changed, 46 insertions(+), 5 deletions(-) rename api/core/workflow/nodes/knowledge_retrieval/{structed_multi_dataset_router_agent.py => multi_dataset_react_route.py} (100%) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 7c2d26bb82..05227d84d9 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ from typing import Any, cast from flask import Flask, current_app from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, InvokeFrom from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError @@ -17,12 +17,12 @@ from core.rerank.rerank import RerankRunner from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base_node import BaseNode, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter -from core.workflow.nodes.knowledge_retrieval.structed_multi_dataset_router_agent import ReactMultiDatasetRouter +from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document, DocumentSegment, DatasetQuery from models.workflow import WorkflowNodeExecutionStatus default_retrieval_model = { @@ -250,6 +250,9 @@ class KnowledgeRetrievalNode(BaseNode): query=query, top_k=top_k, score_threshold=score_threshold, reranking_model=reranking_model) + self._on_query(query, [dataset_id]) + if results: + self._on_retrival_end(results) return results return [] @@ -353,9 +356,46 @@ class KnowledgeRetrievalNode(BaseNode): all_documents = rerank_runner.run(query, all_documents, node_data.multiple_retrieval_config.score_threshold, node_data.multiple_retrieval_config.top_k) - + self._on_query(query, dataset_ids) + if all_documents: + self._on_retrival_end(all_documents) return all_documents + def _on_retrival_end(self, documents: list[Document]) -> None: + """Handle retrival end.""" + for document in documents: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata['doc_id'] + ) + + # if 'dataset_id' in document.metadata: + if 'dataset_id' in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + + # add hit count to document segment + query.update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False + ) + + db.session.commit() + + def _on_query(self, query: str, dataset_ids: list[str]) -> None: + """ + Handle query. + """ + for dataset_id in dataset_ids: + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=query, + source='app', + source_app_id=self.app_id, + created_by_role=self.user_from.value, + created_by=self.user_id + ) + db.session.add(dataset_query) + db.session.commit() + def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): with flask_app.app_context(): dataset = db.session.query(Dataset).filter( @@ -392,3 +432,4 @@ class KnowledgeRetrievalNode(BaseNode): ) all_documents.extend(documents) + diff --git a/api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py b/api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py similarity index 100% rename from api/core/workflow/nodes/knowledge_retrieval/structed_multi_dataset_router_agent.py rename to api/core/workflow/nodes/knowledge_retrieval/multi_dataset_react_route.py