perf: optimize DatasetRetrieval.retrieve、RetrievalService._deduplicat… (#29981)

This commit is contained in:
wangxiaolei 2025-12-22 20:08:21 +08:00 committed by GitHub
parent 4d8223d517
commit eaf4146e2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 201 additions and 145 deletions

View File

@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = [] documents = []
segment_query_stmt = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
)
if document_ids_filter:
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
segments = db.session.execute(segment_query_stmt).scalars().all()
segment_map = {segment.index_node_id: segment for segment in segments}
for chunk_index in sorted_chunk_indices: for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).where( segment = segment_map.get(chunk_index)
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
if segment: if segment:
documents.append( documents.append(

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, load_only from sqlalchemy.orm import Session, load_only
from configs import dify_config from configs import dify_config
from core.db.session_factory import session_factory
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@ -138,37 +139,47 @@ class RetrievalService:
@classmethod
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
    """Deduplicate documents in O(n) while preserving first-seen order.

    Rules:
    - For provider == "dify" with a truthy metadata["doc_id"]: keep the document with the
      highest metadata["score"] among duplicates. A later duplicate whose score is missing
      or None never replaces the one already kept.
    - For every other document (non-dify, or dify without doc_id): deduplicate by
      (provider, page_content), keeping the first occurrence.

    Args:
        documents: Retrieved documents, possibly containing duplicates.

    Returns:
        The input object itself when it is empty; otherwise a new list holding one
        document per dedup key, ordered by the first appearance of each key.
    """
    if not documents:
        return documents

    # Dedup key -> chosen document. Dicts keep insertion order, and re-assigning an
    # existing key does not move it, so first-seen order is preserved without a
    # separate ordering list.
    chosen: dict[tuple, Document] = {}

    for doc in documents:
        metadata = doc.metadata or {}
        doc_id = metadata.get("doc_id") if doc.provider == "dify" else None

        if doc_id:
            # The leading tag keeps id keys and content keys in separate namespaces, so a
            # doc_id that happens to equal some page_content cannot shadow that document.
            key = ("id", doc_id)
            current = chosen.get(key)
            if current is None:
                chosen[key] = doc
                continue

            new_score = metadata.get("score")
            if new_score is None:
                # Nothing to compare with: keep the document already chosen.
                continue
            current_score = (current.metadata or {}).get("score")
            # Only replace when the new score is strictly higher.
            if float(new_score) > (float(current_score) if current_score is not None else 0.0):
                chosen[key] = doc
        else:
            # Content-based dedup: first occurrence wins, no score comparison.
            chosen.setdefault(("content", doc.provider or "dify", doc.page_content), doc)

    return list(chosen.values())
@classmethod @classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None: def _get_dataset(cls, dataset_id: str) -> Dataset | None:
@ -371,58 +382,96 @@ class RetrievalService:
include_segment_ids = set() include_segment_ids = set()
segment_child_map = {} segment_child_map = {}
segment_file_map = {} segment_file_map = {}
with Session(bind=db.engine, expire_on_commit=False) as session:
# Process documents
for document in documents:
segment_id = None
attachment_info = None
child_chunk = None
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
dataset_document = dataset_documents[document_id] valid_dataset_documents = {}
if not dataset_document: image_doc_ids = []
continue child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
continue
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset_document = dataset_documents[document_id]
# Handle parent-child documents if not dataset_document:
if document.metadata.get("doc_type") == DocType.IMAGE: continue
attachment_info_dict = cls.get_segment_attachment_info( valid_dataset_documents[document_id] = dataset_document
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
else:
child_index_node_id = document.metadata.get("doc_id")
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = session.scalar(child_chunk_stmt)
if not child_chunk: if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
continue doc_id = document.metadata.get("doc_id") or ""
segment_id = child_chunk.segment_id doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
child_index_node_ids.append(doc_id)
else:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
index_node_ids.append(doc_id)
if not segment_id: image_doc_ids = [i for i in image_doc_ids if i]
continue child_index_node_ids = [i for i in child_index_node_ids if i]
index_node_ids = [i for i in index_node_ids if i]
segment = ( segment_ids = []
session.query(DocumentSegment) index_node_segments: list[DocumentSegment] = []
.where( segments: list[DocumentSegment] = []
DocumentSegment.dataset_id == dataset_document.dataset_id, attachment_map = {}
DocumentSegment.enabled == True, child_chunk_map = {}
DocumentSegment.status == "completed", doc_segment_map = {}
DocumentSegment.id == segment_id,
)
.first()
)
if not segment: with session_factory.create_session() as session:
continue attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
for attachment in attachments:
segment_ids.append(attachment["segment_id"])
attachment_map[attachment["segment_id"]] = attachment
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
segment_ids.append(i.segment_id)
child_chunk_map[i.segment_id] = i
doc_segment_map[i.segment_id] = i.index_node_id
if index_node_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id.in_(index_node_ids),
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id.in_(segment_ids),
)
segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
if index_node_segments:
segments.extend(index_node_segments)
for segment in segments:
doc_id = doc_segment_map.get(segment.id)
child_chunk = child_chunk_map.get(segment.id)
attachment_info = attachment_map.get(segment.id)
if doc_id:
document = doc_to_document_map[doc_id]
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
document.metadata.get("document_id")
)
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids: if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id) include_segment_ids.add(segment.id)
if child_chunk: if child_chunk:
@ -430,10 +479,10 @@ class RetrievalService:
"id": child_chunk.id, "id": child_chunk.id,
"content": child_chunk.content, "content": child_chunk.content,
"position": child_chunk.position, "position": child_chunk.position,
"score": document.metadata.get("score", 0.0), "score": document.metadata.get("score", 0.0) if document else 0.0,
} }
map_detail = { map_detail = {
"max_score": document.metadata.get("score", 0.0), "max_score": document.metadata.get("score", 0.0) if document else 0.0,
"child_chunks": [child_chunk_detail], "child_chunks": [child_chunk_detail],
} }
segment_child_map[segment.id] = map_detail segment_child_map[segment.id] = map_detail
@ -452,13 +501,14 @@ class RetrievalService:
"score": document.metadata.get("score", 0.0), "score": document.metadata.get("score", 0.0),
} }
if segment.id in segment_child_map: if segment.id in segment_child_map:
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
segment_child_map[segment.id]["max_score"] = max( segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) segment_child_map[segment.id]["max_score"],
document.metadata.get("score", 0.0) if document else 0.0,
) )
else: else:
segment_child_map[segment.id] = { segment_child_map[segment.id] = {
"max_score": document.metadata.get("score", 0.0), "max_score": document.metadata.get("score", 0.0) if document else 0.0,
"child_chunks": [child_chunk_detail], "child_chunks": [child_chunk_detail],
} }
if attachment_info: if attachment_info:
@ -467,46 +517,11 @@ class RetrievalService:
else: else:
segment_file_map[segment.id] = [attachment_info] segment_file_map[segment.id] = [attachment_info]
else: else:
# Handle normal documents
segment = None
if document.metadata.get("doc_type") == DocType.IMAGE:
attachment_info_dict = cls.get_segment_attachment_info(
dataset_document.dataset_id,
dataset_document.tenant_id,
document.metadata.get("doc_id") or "",
session,
)
if attachment_info_dict:
attachment_info = attachment_info_dict["attachment_info"]
segment_id = attachment_info_dict["segment_id"]
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == segment_id,
)
segment = session.scalar(document_segment_stmt)
if segment:
segment_file_map[segment.id] = [attachment_info]
else:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = session.scalar(document_segment_stmt)
if not segment:
continue
if segment.id not in include_segment_ids: if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id) include_segment_ids.add(segment.id)
record = { record = {
"segment": segment, "segment": segment,
"score": document.metadata.get("score"), # type: ignore "score": document.metadata.get("score", 0.0), # type: ignore
} }
if attachment_info: if attachment_info:
segment_file_map[segment.id] = [attachment_info] segment_file_map[segment.id] = [attachment_info]
@ -522,7 +537,7 @@ class RetrievalService:
for record in records: for record in records:
if record["segment"].id in segment_child_map: if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
record["score"] = segment_child_map[record["segment"].id]["max_score"] record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
if record["segment"].id in segment_file_map: if record["segment"].id in segment_file_map:
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
@ -565,6 +580,8 @@ class RetrievalService:
flask_app: Flask, flask_app: Flask,
retrieval_method: RetrievalMethod, retrieval_method: RetrievalMethod,
dataset: Dataset, dataset: Dataset,
all_documents: list[Document],
exceptions: list[str],
query: str | None = None, query: str | None = None,
top_k: int = 4, top_k: int = 4,
score_threshold: float | None = 0.0, score_threshold: float | None = 0.0,
@ -573,8 +590,6 @@ class RetrievalService:
weights: dict | None = None, weights: dict | None = None,
document_ids_filter: list[str] | None = None, document_ids_filter: list[str] | None = None,
attachment_id: str | None = None, attachment_id: str | None = None,
all_documents: list[Document] = [],
exceptions: list[str] = [],
): ):
if not query and not attachment_id: if not query and not attachment_id:
return return
@ -696,3 +711,37 @@ class RetrievalService:
} }
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id} return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
return None return None
@classmethod
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
    """Batch-load attachment details together with their segment bindings.

    Args:
        attachment_ids: IDs matched against ``UploadFile.id``.
        session: Session used for both the upload-file and the binding lookup.

    Returns:
        One dict per matched upload file that has a segment binding, with the keys
        ``attachment_id``, ``attachment_info`` and ``segment_id``. Upload files without
        a binding are omitted. An empty list is returned when ``attachment_ids`` is
        empty or nothing matches.
    """
    # Callers pass the list unconditionally; skip the database round trips when empty.
    if not attachment_ids:
        return []

    upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
    if not upload_files:
        return []

    attachment_bindings = (
        session.query(SegmentAttachmentBinding)
        .where(SegmentAttachmentBinding.attachment_id.in_([upload_file.id for upload_file in upload_files]))
        .all()
    )
    # If several bindings share an attachment_id, the last one returned by the query wins.
    binding_by_attachment = {binding.attachment_id: binding for binding in attachment_bindings}

    attachment_infos: list[dict[str, Any]] = []
    for upload_file in upload_files:
        attachment_binding = binding_by_attachment.get(upload_file.id)
        if not attachment_binding:
            # Unbound files are not reported, so do not bother signing a URL for them.
            continue
        attachment_info = {
            "id": upload_file.id,
            "name": upload_file.name,
            "extension": "." + upload_file.extension,
            "mime_type": upload_file.mime_type,
            "source_url": sign_upload_file(upload_file.id, upload_file.extension),
            "size": upload_file.size,
        }
        attachment_infos.append(
            {
                "attachment_id": attachment_binding.attachment_id,
                "attachment_info": attachment_info,
                "segment_id": attachment_binding.segment_id,
            }
        )
    return attachment_infos

View File

@ -151,20 +151,14 @@ class DatasetRetrieval:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER planning_strategy = PlanningStrategy.ROUTER
available_datasets = [] available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
if not dataset: datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
for dataset in datasets:
if dataset.available_document_count == 0 and dataset.provider != "external":
continue continue
# pass if dataset is not available
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset) available_datasets.append(dataset)
if inputs: if inputs:
inputs = {key: str(value) for key, value in inputs.items()} inputs = {key: str(value) for key, value in inputs.items()}
else: else:
@ -282,26 +276,35 @@ class DatasetRetrieval:
) )
context_files.append(attachment_info) context_files.append(attachment_info)
if show_retrieve_source: if show_retrieve_source:
dataset_ids = [record.segment.dataset_id for record in records]
document_ids = [record.segment.document_id for record in records]
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id.in_(document_ids),
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
documents = db.session.execute(dataset_document_stmt).scalars().all() # type: ignore
dataset_stmt = select(Dataset).where(
Dataset.id.in_(dataset_ids),
)
datasets = db.session.execute(dataset_stmt).scalars().all() # type: ignore
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset_item = dataset_map.get(segment.dataset_id)
dataset_document_stmt = select(DatasetDocument).where( document_item = document_map.get(segment.document_id)
DatasetDocument.id == segment.document_id, if dataset_item and document_item:
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata( source = RetrievalSourceMetadata(
dataset_id=dataset.id, dataset_id=dataset_item.id,
dataset_name=dataset.name, dataset_name=dataset_item.name,
document_id=document.id, document_id=document_item.id,
document_name=document.name, document_name=document_item.name,
data_source_type=document.data_source_type, data_source_type=document_item.data_source_type,
segment_id=segment.id, segment_id=segment.id,
retriever_from=invoke_from.to_source(), retriever_from=invoke_from.to_source(),
score=record.score or 0.0, score=record.score or 0.0,
doc_metadata=document.doc_metadata, doc_metadata=document_item.doc_metadata,
) )
if invoke_from.to_source() == "dev": if invoke_from.to_source() == "dev":