perf: optimize DatasetRetrieval.retrieve, RetrievalService._deduplicat… (#29981)

wangxiaolei 2025-12-22 20:08:21 +08:00 committed by GitHub
parent 4d8223d517
commit eaf4146e2f
3 changed files with 201 additions and 145 deletions

View File

@@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
         documents = []
+        segment_query_stmt = db.session.query(DocumentSegment).where(
+            DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
+        )
+        if document_ids_filter:
+            segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
+        segments = db.session.execute(segment_query_stmt).scalars().all()
+        segment_map = {segment.index_node_id: segment for segment in segments}
         for chunk_index in sorted_chunk_indices:
-            segment_query = db.session.query(DocumentSegment).where(
-                DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
-            )
-            if document_ids_filter:
-                segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
-            segment = segment_query.first()
+            segment = segment_map.get(chunk_index)
             if segment:
                 documents.append(

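The hunk above collapses an N+1 query pattern: instead of one DocumentSegment query per ranked chunk index, it issues a single IN (...) query and re-applies the ranking order from an in-memory map. A minimal standalone sketch of the same pattern, with hypothetical `session` and `Segment` names standing in for Dify's actual models:

    from sqlalchemy import select

    def fetch_segments_in_rank_order(session, Segment, node_ids):
        # One IN (...) query replaces len(node_ids) round trips.
        rows = session.execute(select(Segment).where(Segment.index_node_id.in_(node_ids))).scalars().all()
        by_node_id = {row.index_node_id: row for row in rows}
        # Preserve the caller's ranking; ids with no matching segment are skipped.
        return [by_node_id[node_id] for node_id in node_ids if node_id in by_node_id]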
View File

@@ -7,6 +7,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session, load_only
 from configs import dify_config
+from core.db.session_factory import session_factory
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@@ -138,37 +139,47 @@ class RetrievalService:
     @classmethod
     def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
-        """Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
+        """Deduplicate documents in O(n) while preserving first-seen order.
+
+        Rules:
+        - For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
+          metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
+        - For non-dify documents (or dify without doc_id): deduplicate by content key
+          (provider, page_content), keeping the first occurrence.
+        """
         if not documents:
             return documents
-        unique_documents = []
-        seen_doc_ids = set()
+        # Map of dedup key -> chosen Document
+        chosen: dict[tuple, Document] = {}
+        # Preserve the order of first appearance of each dedup key
+        order: list[tuple] = []
-        for document in documents:
-            # For dify provider documents, use doc_id for deduplication
-            if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
-                doc_id = document.metadata["doc_id"]
-                if doc_id not in seen_doc_ids:
-                    seen_doc_ids.add(doc_id)
-                    unique_documents.append(document)
-                # If duplicate, keep the one with higher score
-                elif "score" in document.metadata:
-                    # Find existing document with same doc_id and compare scores
-                    for i, existing_doc in enumerate(unique_documents):
-                        if (
-                            existing_doc.metadata
-                            and existing_doc.metadata.get("doc_id") == doc_id
-                            and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
-                        ):
-                            unique_documents[i] = document
-                            break
+        for doc in documents:
+            is_dify = doc.provider == "dify"
+            doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
+            if is_dify and doc_id:
+                key = ("dify", doc_id)
+                if key not in chosen:
+                    chosen[key] = doc
+                    order.append(key)
+                else:
+                    # Only replace if the new one has a score and it's strictly higher
+                    if "score" in doc.metadata:
+                        new_score = float(doc.metadata.get("score", 0.0))
+                        old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
+                        if new_score > old_score:
+                            chosen[key] = doc
             else:
-                # For non-dify documents, use content-based deduplication
-                if document not in unique_documents:
-                    unique_documents.append(document)
+                # Content-based dedup for non-dify or dify without doc_id
+                content_key = (doc.provider or "dify", doc.page_content)
+                if content_key not in chosen:
+                    chosen[content_key] = doc
+                    order.append(content_key)
+                # If duplicate content appears, we keep the first occurrence (no score comparison)
-        return unique_documents
+        return [chosen[k] for k in order]
 
     @classmethod
     def _get_dataset(cls, dataset_id: str) -> Dataset | None:
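The docstring's first rule can be checked in isolation. A toy run of the dify/doc_id branch, with plain dicts standing in for Document objects (illustrative, not Dify's classes):

    docs = [
        {"provider": "dify", "metadata": {"doc_id": "a", "score": 0.42}},
        {"provider": "dify", "metadata": {"doc_id": "a", "score": 0.91}},  # higher score replaces the first
        {"provider": "dify", "metadata": {"doc_id": "a"}},                 # scoreless duplicate is ignored
    ]
    chosen, order = {}, []
    for d in docs:
        key = ("dify", d["metadata"]["doc_id"])
        if key not in chosen:
            chosen[key] = d
            order.append(key)
        elif "score" in d["metadata"] and d["metadata"]["score"] > chosen[key]["metadata"].get("score", 0.0):
            chosen[key] = d
    assert [chosen[k]["metadata"]["score"] for k in order] == [0.91]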
@@ -371,58 +382,96 @@ class RetrievalService:
         include_segment_ids = set()
         segment_child_map = {}
         segment_file_map = {}
-        with Session(bind=db.engine, expire_on_commit=False) as session:
-            # Process documents
-            for document in documents:
-                segment_id = None
-                attachment_info = None
-                child_chunk = None
-                document_id = document.metadata.get("document_id")
-                if document_id not in dataset_documents:
-                    continue
-                dataset_document = dataset_documents[document_id]
-                if not dataset_document:
-                    continue
+        valid_dataset_documents = {}
+        image_doc_ids = []
+        child_index_node_ids = []
+        index_node_ids = []
+        doc_to_document_map = {}
+        for document in documents:
+            document_id = document.metadata.get("document_id")
+            if document_id not in dataset_documents:
+                continue
-                if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                    # Handle parent-child documents
-                    if document.metadata.get("doc_type") == DocType.IMAGE:
-                        attachment_info_dict = cls.get_segment_attachment_info(
-                            dataset_document.dataset_id,
-                            dataset_document.tenant_id,
-                            document.metadata.get("doc_id") or "",
-                            session,
-                        )
-                        if attachment_info_dict:
-                            attachment_info = attachment_info_dict["attachment_info"]
-                            segment_id = attachment_info_dict["segment_id"]
-                    else:
-                        child_index_node_id = document.metadata.get("doc_id")
-                        child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
-                        child_chunk = session.scalar(child_chunk_stmt)
+            dataset_document = dataset_documents[document_id]
+            if not dataset_document:
+                continue
+            valid_dataset_documents[document_id] = dataset_document
-                        if not child_chunk:
-                            continue
-                        segment_id = child_chunk.segment_id
+            if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+                doc_id = document.metadata.get("doc_id") or ""
+                doc_to_document_map[doc_id] = document
+                if document.metadata.get("doc_type") == DocType.IMAGE:
+                    image_doc_ids.append(doc_id)
+                else:
+                    child_index_node_ids.append(doc_id)
+            else:
+                doc_id = document.metadata.get("doc_id") or ""
+                doc_to_document_map[doc_id] = document
+                if document.metadata.get("doc_type") == DocType.IMAGE:
+                    image_doc_ids.append(doc_id)
+                else:
+                    index_node_ids.append(doc_id)
-                    if not segment_id:
-                        continue
+        image_doc_ids = [i for i in image_doc_ids if i]
+        child_index_node_ids = [i for i in child_index_node_ids if i]
+        index_node_ids = [i for i in index_node_ids if i]
-                    segment = (
-                        session.query(DocumentSegment)
-                        .where(
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                            DocumentSegment.id == segment_id,
-                        )
-                        .first()
-                    )
+        segment_ids = []
+        index_node_segments: list[DocumentSegment] = []
+        segments: list[DocumentSegment] = []
+        attachment_map = {}
+        child_chunk_map = {}
+        doc_segment_map = {}
-                    if not segment:
-                        continue
+        with session_factory.create_session() as session:
+            attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
+            for attachment in attachments:
+                segment_ids.append(attachment["segment_id"])
+                attachment_map[attachment["segment_id"]] = attachment
+                doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
+            child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
+            child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
+            for i in child_index_nodes:
+                segment_ids.append(i.segment_id)
+                child_chunk_map[i.segment_id] = i
+                doc_segment_map[i.segment_id] = i.index_node_id
+            if index_node_ids:
+                document_segment_stmt = select(DocumentSegment).where(
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.index_node_id.in_(index_node_ids),
+                )
+                index_node_segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
+                for index_node_segment in index_node_segments:
+                    doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
+            if segment_ids:
+                document_segment_stmt = select(DocumentSegment).where(
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.id.in_(segment_ids),
+                )
+                segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
+            if index_node_segments:
+                segments.extend(index_node_segments)
+            for segment in segments:
+                doc_id = doc_segment_map.get(segment.id)
+                child_chunk = child_chunk_map.get(segment.id)
+                attachment_info = attachment_map.get(segment.id)
+                if doc_id:
+                    document = doc_to_document_map[doc_id]
+                ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
+                    document.metadata.get("document_id")
+                )
+                if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
                     if segment.id not in include_segment_ids:
                         include_segment_ids.add(segment.id)
                         if child_chunk:
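This hunk is the core of the optimization: the old code issued up to three queries per retrieved document inside the loop (attachment lookup, ChildChunk scalar, DocumentSegment .first()), while the new code classifies ids in one pass, batch-fetches each kind, and joins the results through dicts. The shape of that rewrite, reduced to a runnable skeleton (the function name, parameters, and two-kind split are illustrative, not Dify's API):

    from sqlalchemy import select

    def batch_resolve(session, ChildChunk, DocumentSegment, child_ids, node_ids):
        # Phase 2 of classify/batch/join: one query per id kind, not one per document.
        child_chunks = session.execute(
            select(ChildChunk).where(ChildChunk.index_node_id.in_(child_ids))
        ).scalars().all()
        segments = session.execute(
            select(DocumentSegment).where(DocumentSegment.index_node_id.in_(node_ids))
        ).scalars().all()
        # Phase 3: dict joins give O(1) lookups in the final per-segment loop.
        child_by_segment = {c.segment_id: c for c in child_chunks}
        return segments, child_by_segment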
@@ -430,10 +479,10 @@ class RetrievalService:
                                 "id": child_chunk.id,
                                 "content": child_chunk.content,
                                 "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
+                                "score": document.metadata.get("score", 0.0) if document else 0.0,
                             }
                             map_detail = {
-                                "max_score": document.metadata.get("score", 0.0),
+                                "max_score": document.metadata.get("score", 0.0) if document else 0.0,
                                 "child_chunks": [child_chunk_detail],
                             }
                             segment_child_map[segment.id] = map_detail
@@ -452,13 +501,14 @@ class RetrievalService:
                                 "score": document.metadata.get("score", 0.0),
                             }
                             if segment.id in segment_child_map:
-                                segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                                segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)  # type: ignore
                                 segment_child_map[segment.id]["max_score"] = max(
-                                    segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                                    segment_child_map[segment.id]["max_score"],
+                                    document.metadata.get("score", 0.0) if document else 0.0,
                                 )
                             else:
                                 segment_child_map[segment.id] = {
-                                    "max_score": document.metadata.get("score", 0.0),
+                                    "max_score": document.metadata.get("score", 0.0) if document else 0.0,
                                     "child_chunks": [child_chunk_detail],
                                 }
                             if attachment_info:
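These two hunks keep the existing per-parent aggregation (collect child chunks under their parent segment, track the best score seen so far) and only add None-guards and type-checker pragmas. The aggregation pattern itself, in isolation with toy data:

    hits = [("seg-1", 0.3), ("seg-2", 0.5), ("seg-1", 0.8)]  # (parent segment id, child score)
    segment_child_map = {}
    for seg_id, score in hits:
        if seg_id in segment_child_map:
            segment_child_map[seg_id]["child_chunks"].append(score)
            segment_child_map[seg_id]["max_score"] = max(segment_child_map[seg_id]["max_score"], score)
        else:
            segment_child_map[seg_id] = {"max_score": score, "child_chunks": [score]}
    assert segment_child_map["seg-1"] == {"max_score": 0.8, "child_chunks": [0.3, 0.8]}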
@@ -467,46 +517,11 @@ class RetrievalService:
                                 else:
                                     segment_file_map[segment.id] = [attachment_info]
                 else:
-                    # Handle normal documents
-                    segment = None
-                    if document.metadata.get("doc_type") == DocType.IMAGE:
-                        attachment_info_dict = cls.get_segment_attachment_info(
-                            dataset_document.dataset_id,
-                            dataset_document.tenant_id,
-                            document.metadata.get("doc_id") or "",
-                            session,
-                        )
-                        if attachment_info_dict:
-                            attachment_info = attachment_info_dict["attachment_info"]
-                            segment_id = attachment_info_dict["segment_id"]
-                            document_segment_stmt = select(DocumentSegment).where(
-                                DocumentSegment.dataset_id == dataset_document.dataset_id,
-                                DocumentSegment.enabled == True,
-                                DocumentSegment.status == "completed",
-                                DocumentSegment.id == segment_id,
-                            )
-                            segment = session.scalar(document_segment_stmt)
-                            if segment:
-                                segment_file_map[segment.id] = [attachment_info]
-                    else:
-                        index_node_id = document.metadata.get("doc_id")
-                        if not index_node_id:
-                            continue
-                        document_segment_stmt = select(DocumentSegment).where(
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                            DocumentSegment.index_node_id == index_node_id,
-                        )
-                        segment = session.scalar(document_segment_stmt)
-                        if not segment:
-                            continue
                     if segment.id not in include_segment_ids:
                         include_segment_ids.add(segment.id)
                         record = {
                             "segment": segment,
-                            "score": document.metadata.get("score"),  # type: ignore
+                            "score": document.metadata.get("score", 0.0),  # type: ignore
                         }
                         if attachment_info:
                             segment_file_map[segment.id] = [attachment_info]
@@ -522,7 +537,7 @@ class RetrievalService:
         for record in records:
             if record["segment"].id in segment_child_map:
                 record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
-                record["score"] = segment_child_map[record["segment"].id]["max_score"]
+                record["score"] = segment_child_map[record["segment"].id]["max_score"]  # type: ignore
             if record["segment"].id in segment_file_map:
                 record["files"] = segment_file_map[record["segment"].id]  # type: ignore[assignment]
@@ -565,6 +580,8 @@ class RetrievalService:
         flask_app: Flask,
         retrieval_method: RetrievalMethod,
         dataset: Dataset,
+        all_documents: list[Document],
+        exceptions: list[str],
         query: str | None = None,
         top_k: int = 4,
         score_threshold: float | None = 0.0,
@@ -573,8 +590,6 @@ class RetrievalService:
         weights: dict | None = None,
         document_ids_filter: list[str] | None = None,
         attachment_id: str | None = None,
-        all_documents: list[Document] = [],
-        exceptions: list[str] = [],
     ):
         if not query and not attachment_id:
             return
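Moving all_documents and exceptions from defaulted parameters to required ones also removes Python's mutable-default pitfall: a default [] is created once at definition time and shared by every call that omits the argument. A standalone demonstration of why the old signature was risky (not Dify code):

    def collect_buggy(item, bucket: list = []):  # the pattern the diff removes
        bucket.append(item)
        return bucket

    def collect(item, bucket: list | None = None):  # the safe idiom
        bucket = [] if bucket is None else bucket
        bucket.append(item)
        return bucket

    assert collect_buggy("a") == ["a"]
    assert collect_buggy("b") == ["a", "b"]  # surprise: state leaks across calls
    assert collect("a") == ["a"]
    assert collect("b") == ["b"]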
@@ -696,3 +711,37 @@ class RetrievalService:
             }
             return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
         return None
+
+    @classmethod
+    def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
+        attachment_infos = []
+        upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
+        if upload_files:
+            upload_file_ids = [upload_file.id for upload_file in upload_files]
+            attachment_bindings = (
+                session.query(SegmentAttachmentBinding)
+                .where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
+                .all()
+            )
+            attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
+            if attachment_bindings:
+                for upload_file in upload_files:
+                    attachment_binding = attachment_binding_map.get(upload_file.id)
+                    attachment_info = {
+                        "id": upload_file.id,
+                        "name": upload_file.name,
+                        "extension": "." + upload_file.extension,
+                        "mime_type": upload_file.mime_type,
+                        "source_url": sign_upload_file(upload_file.id, upload_file.extension),
+                        "size": upload_file.size,
+                    }
+                    if attachment_binding:
+                        attachment_infos.append(
+                            {
+                                "attachment_id": attachment_binding.attachment_id,
+                                "attachment_info": attachment_info,
+                                "segment_id": attachment_binding.segment_id,
+                            }
+                        )
+        return attachment_infos
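A sketch of how the new batch helper is consumed; this mirrors the call site added earlier in this diff, with image_doc_ids assumed to be collected beforehand as shown above:

    with session_factory.create_session() as session:
        attachments = RetrievalService.get_segment_attachment_infos(image_doc_ids, session)
    # Index once, then resolve each segment's attachment in O(1).
    attachment_map = {a["segment_id"]: a for a in attachments}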

View File

@@ -151,20 +151,14 @@ class DatasetRetrieval:
         if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
             planning_strategy = PlanningStrategy.ROUTER
         available_datasets = []
-        for dataset_id in dataset_ids:
-            # get dataset from dataset id
-            dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
-            dataset = db.session.scalar(dataset_stmt)
-            # pass if dataset is not available
-            if not dataset:
+        dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
+        datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all()  # type: ignore
+        for dataset in datasets:
+            if dataset.available_document_count == 0 and dataset.provider != "external":
                 continue
-            # pass if dataset is not available
-            if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
-                continue
             available_datasets.append(dataset)
         if inputs:
             inputs = {key: str(value) for key, value in inputs.items()}
         else:
@@ -282,26 +276,35 @@ class DatasetRetrieval:
                     )
                     context_files.append(attachment_info)
         if show_retrieve_source:
+            dataset_ids = [record.segment.dataset_id for record in records]
+            document_ids = [record.segment.document_id for record in records]
+            dataset_document_stmt = select(DatasetDocument).where(
+                DatasetDocument.id.in_(document_ids),
+                DatasetDocument.enabled == True,
+                DatasetDocument.archived == False,
+            )
+            documents = db.session.execute(dataset_document_stmt).scalars().all()  # type: ignore
+            dataset_stmt = select(Dataset).where(
+                Dataset.id.in_(dataset_ids),
+            )
+            datasets = db.session.execute(dataset_stmt).scalars().all()  # type: ignore
+            dataset_map = {i.id: i for i in datasets}
+            document_map = {i.id: i for i in documents}
             for record in records:
                 segment = record.segment
-                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
-                dataset_document_stmt = select(DatasetDocument).where(
-                    DatasetDocument.id == segment.document_id,
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                )
-                document = db.session.scalar(dataset_document_stmt)
-                if dataset and document:
+                dataset_item = dataset_map.get(segment.dataset_id)
+                document_item = document_map.get(segment.document_id)
+                if dataset_item and document_item:
                     source = RetrievalSourceMetadata(
-                        dataset_id=dataset.id,
-                        dataset_name=dataset.name,
-                        document_id=document.id,
-                        document_name=document.name,
-                        data_source_type=document.data_source_type,
+                        dataset_id=dataset_item.id,
+                        dataset_name=dataset_item.name,
+                        document_id=document_item.id,
+                        document_name=document_item.name,
+                        data_source_type=document_item.data_source_type,
                         segment_id=segment.id,
                         retriever_from=invoke_from.to_source(),
                         score=record.score or 0.0,
-                        doc_metadata=document.doc_metadata,
+                        doc_metadata=document_item.doc_metadata,
                     )
                     if invoke_from.to_source() == "dev":
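The final hunk applies the same two-bulk-queries-plus-dict-join pattern as the dataset lookup earlier in this file: fetch all datasets and documents once, then resolve each record from maps instead of issuing two queries per record. A runnable miniature of the join, with a dataclass standing in for the ORM models (illustrative data):

    from dataclasses import dataclass

    @dataclass
    class Row:
        id: str
        name: str

    datasets = [Row("ds1", "FAQ"), Row("ds2", "Docs")]  # result of one IN (...) query
    documents = [Row("doc1", "intro.md")]               # result of one IN (...) query
    dataset_map = {r.id: r for r in datasets}
    document_map = {r.id: r for r in documents}

    # Per-record resolution is now two dict lookups instead of two SELECTs.
    assert dataset_map.get("ds2").name == "Docs"
    assert document_map.get("missing") is None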