mirror of
https://github.com/langgenius/dify.git
synced 2026-04-20 14:17:18 +08:00
fix knowledge retrival
This commit is contained in:
parent
ef39fa3fb2
commit
f19219ab8d
@ -4,7 +4,7 @@ from typing import Any, cast
|
|||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
|
||||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity, InvokeFrom
|
||||||
from core.entities.agent_entities import PlanningStrategy
|
from core.entities.agent_entities import PlanningStrategy
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
@ -17,12 +17,12 @@ from core.rerank.rerank import RerankRunner
|
|||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode, UserFrom
|
||||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||||
from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
from core.workflow.nodes.knowledge_retrieval.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||||
from core.workflow.nodes.knowledge_retrieval.structed_multi_dataset_router_agent import ReactMultiDatasetRouter
|
from core.workflow.nodes.knowledge_retrieval.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.dataset import Dataset, Document, DocumentSegment
|
from models.dataset import Dataset, Document, DocumentSegment, DatasetQuery
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
default_retrieval_model = {
|
default_retrieval_model = {
|
||||||
@ -250,6 +250,9 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
query=query,
|
query=query,
|
||||||
top_k=top_k, score_threshold=score_threshold,
|
top_k=top_k, score_threshold=score_threshold,
|
||||||
reranking_model=reranking_model)
|
reranking_model=reranking_model)
|
||||||
|
self._on_query(query, [dataset_id])
|
||||||
|
if results:
|
||||||
|
self._on_retrival_end(results)
|
||||||
return results
|
return results
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -353,9 +356,46 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
all_documents = rerank_runner.run(query, all_documents,
|
all_documents = rerank_runner.run(query, all_documents,
|
||||||
node_data.multiple_retrieval_config.score_threshold,
|
node_data.multiple_retrieval_config.score_threshold,
|
||||||
node_data.multiple_retrieval_config.top_k)
|
node_data.multiple_retrieval_config.top_k)
|
||||||
|
self._on_query(query, dataset_ids)
|
||||||
|
if all_documents:
|
||||||
|
self._on_retrival_end(all_documents)
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
|
def _on_retrival_end(self, documents: list[Document]) -> None:
|
||||||
|
"""Handle retrival end."""
|
||||||
|
for document in documents:
|
||||||
|
query = db.session.query(DocumentSegment).filter(
|
||||||
|
DocumentSegment.index_node_id == document.metadata['doc_id']
|
||||||
|
)
|
||||||
|
|
||||||
|
# if 'dataset_id' in document.metadata:
|
||||||
|
if 'dataset_id' in document.metadata:
|
||||||
|
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
|
||||||
|
|
||||||
|
# add hit count to document segment
|
||||||
|
query.update(
|
||||||
|
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
|
||||||
|
synchronize_session=False
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
def _on_query(self, query: str, dataset_ids: list[str]) -> None:
|
||||||
|
"""
|
||||||
|
Handle query.
|
||||||
|
"""
|
||||||
|
for dataset_id in dataset_ids:
|
||||||
|
dataset_query = DatasetQuery(
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
content=query,
|
||||||
|
source='app',
|
||||||
|
source_app_id=self.app_id,
|
||||||
|
created_by_role=self.user_from.value,
|
||||||
|
created_by=self.user_id
|
||||||
|
)
|
||||||
|
db.session.add(dataset_query)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(
|
||||||
@ -392,3 +432,4 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user