mirror of https://github.com/langgenius/dify.git
perf: optimize DatasetRetrieval.retrieve、RetrievalService._deduplicat… (#29981)
This commit is contained in:
parent
4d8223d517
commit
eaf4146e2f
|
|
@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
|
|||
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
|
||||
|
||||
documents = []
|
||||
|
||||
segment_query_stmt = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
|
||||
segments = db.session.execute(segment_query_stmt).scalars().all()
|
||||
segment_map = {segment.index_node_id: segment for segment in segments}
|
||||
for chunk_index in sorted_chunk_indices:
|
||||
segment_query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
|
||||
)
|
||||
if document_ids_filter:
|
||||
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
|
||||
segment = segment_query.first()
|
||||
segment = segment_map.get(chunk_index)
|
||||
|
||||
if segment:
|
||||
documents.append(
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.orm import Session, load_only
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
|
|
@ -138,37 +139,47 @@ class RetrievalService:
|
|||
|
||||
@classmethod
|
||||
def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
|
||||
"""Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
|
||||
"""Deduplicate documents in O(n) while preserving first-seen order.
|
||||
|
||||
Rules:
|
||||
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
|
||||
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
|
||||
- For non-dify documents (or dify without doc_id): deduplicate by content key
|
||||
(provider, page_content), keeping the first occurrence.
|
||||
"""
|
||||
if not documents:
|
||||
return documents
|
||||
|
||||
unique_documents = []
|
||||
seen_doc_ids = set()
|
||||
# Map of dedup key -> chosen Document
|
||||
chosen: dict[tuple, Document] = {}
|
||||
# Preserve the order of first appearance of each dedup key
|
||||
order: list[tuple] = []
|
||||
|
||||
for document in documents:
|
||||
# For dify provider documents, use doc_id for deduplication
|
||||
if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
|
||||
doc_id = document.metadata["doc_id"]
|
||||
if doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
unique_documents.append(document)
|
||||
# If duplicate, keep the one with higher score
|
||||
elif "score" in document.metadata:
|
||||
# Find existing document with same doc_id and compare scores
|
||||
for i, existing_doc in enumerate(unique_documents):
|
||||
if (
|
||||
existing_doc.metadata
|
||||
and existing_doc.metadata.get("doc_id") == doc_id
|
||||
and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
|
||||
):
|
||||
unique_documents[i] = document
|
||||
break
|
||||
for doc in documents:
|
||||
is_dify = doc.provider == "dify"
|
||||
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
|
||||
|
||||
if is_dify and doc_id:
|
||||
key = ("dify", doc_id)
|
||||
if key not in chosen:
|
||||
chosen[key] = doc
|
||||
order.append(key)
|
||||
else:
|
||||
# Only replace if the new one has a score and it's strictly higher
|
||||
if "score" in doc.metadata:
|
||||
new_score = float(doc.metadata.get("score", 0.0))
|
||||
old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
|
||||
if new_score > old_score:
|
||||
chosen[key] = doc
|
||||
else:
|
||||
# For non-dify documents, use content-based deduplication
|
||||
if document not in unique_documents:
|
||||
unique_documents.append(document)
|
||||
# Content-based dedup for non-dify or dify without doc_id
|
||||
content_key = (doc.provider or "dify", doc.page_content)
|
||||
if content_key not in chosen:
|
||||
chosen[content_key] = doc
|
||||
order.append(content_key)
|
||||
# If duplicate content appears, we keep the first occurrence (no score comparison)
|
||||
|
||||
return unique_documents
|
||||
return [chosen[k] for k in order]
|
||||
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
||||
|
|
@ -371,58 +382,96 @@ class RetrievalService:
|
|||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
segment_file_map = {}
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# Process documents
|
||||
for document in documents:
|
||||
segment_id = None
|
||||
attachment_info = None
|
||||
child_chunk = None
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
valid_dataset_documents = {}
|
||||
image_doc_ids = []
|
||||
child_index_node_ids = []
|
||||
index_node_ids = []
|
||||
doc_to_document_map = {}
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
continue
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
# Handle parent-child documents
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
attachment_info_dict = cls.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
else:
|
||||
child_index_node_id = document.metadata.get("doc_id")
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
|
||||
child_chunk = session.scalar(child_chunk_stmt)
|
||||
dataset_document = dataset_documents[document_id]
|
||||
if not dataset_document:
|
||||
continue
|
||||
valid_dataset_documents[document_id] = dataset_document
|
||||
|
||||
if not child_chunk:
|
||||
continue
|
||||
segment_id = child_chunk.segment_id
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
image_doc_ids.append(doc_id)
|
||||
else:
|
||||
child_index_node_ids.append(doc_id)
|
||||
else:
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
image_doc_ids.append(doc_id)
|
||||
else:
|
||||
index_node_ids.append(doc_id)
|
||||
|
||||
if not segment_id:
|
||||
continue
|
||||
image_doc_ids = [i for i in image_doc_ids if i]
|
||||
child_index_node_ids = [i for i in child_index_node_ids if i]
|
||||
index_node_ids = [i for i in index_node_ids if i]
|
||||
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
segment_ids = []
|
||||
index_node_segments: list[DocumentSegment] = []
|
||||
segments: list[DocumentSegment] = []
|
||||
attachment_map = {}
|
||||
child_chunk_map = {}
|
||||
doc_segment_map = {}
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
with session_factory.create_session() as session:
|
||||
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
||||
|
||||
for attachment in attachments:
|
||||
segment_ids.append(attachment["segment_id"])
|
||||
attachment_map[attachment["segment_id"]] = attachment
|
||||
doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
|
||||
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
|
||||
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
||||
|
||||
for i in child_index_nodes:
|
||||
segment_ids.append(i.segment_id)
|
||||
child_chunk_map[i.segment_id] = i
|
||||
doc_segment_map[i.segment_id] = i.index_node_id
|
||||
|
||||
if index_node_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
)
|
||||
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||
for index_node_segment in index_node_segments:
|
||||
doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
|
||||
if segment_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
)
|
||||
segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||
|
||||
if index_node_segments:
|
||||
segments.extend(index_node_segments)
|
||||
|
||||
for segment in segments:
|
||||
doc_id = doc_segment_map.get(segment.id)
|
||||
child_chunk = child_chunk_map.get(segment.id)
|
||||
attachment_info = attachment_map.get(segment.id)
|
||||
|
||||
if doc_id:
|
||||
document = doc_to_document_map[doc_id]
|
||||
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
|
||||
document.metadata.get("document_id")
|
||||
)
|
||||
|
||||
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
if child_chunk:
|
||||
|
|
@ -430,10 +479,10 @@ class RetrievalService:
|
|||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
"position": child_chunk.position,
|
||||
"score": document.metadata.get("score", 0.0),
|
||||
"score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
}
|
||||
map_detail = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
segment_child_map[segment.id] = map_detail
|
||||
|
|
@ -452,13 +501,14 @@ class RetrievalService:
|
|||
"score": document.metadata.get("score", 0.0),
|
||||
}
|
||||
if segment.id in segment_child_map:
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
|
||||
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) # type: ignore
|
||||
segment_child_map[segment.id]["max_score"] = max(
|
||||
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
|
||||
segment_child_map[segment.id]["max_score"],
|
||||
document.metadata.get("score", 0.0) if document else 0.0,
|
||||
)
|
||||
else:
|
||||
segment_child_map[segment.id] = {
|
||||
"max_score": document.metadata.get("score", 0.0),
|
||||
"max_score": document.metadata.get("score", 0.0) if document else 0.0,
|
||||
"child_chunks": [child_chunk_detail],
|
||||
}
|
||||
if attachment_info:
|
||||
|
|
@ -467,46 +517,11 @@ class RetrievalService:
|
|||
else:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
# Handle normal documents
|
||||
segment = None
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
attachment_info_dict = cls.get_segment_attachment_info(
|
||||
dataset_document.dataset_id,
|
||||
dataset_document.tenant_id,
|
||||
document.metadata.get("doc_id") or "",
|
||||
session,
|
||||
)
|
||||
if attachment_info_dict:
|
||||
attachment_info = attachment_info_dict["attachment_info"]
|
||||
segment_id = attachment_info_dict["segment_id"]
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id == segment_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
if segment:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
else:
|
||||
index_node_id = document.metadata.get("doc_id")
|
||||
if not index_node_id:
|
||||
continue
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset_document.dataset_id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
segment = session.scalar(document_segment_stmt)
|
||||
|
||||
if not segment:
|
||||
continue
|
||||
if segment.id not in include_segment_ids:
|
||||
include_segment_ids.add(segment.id)
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get("score"), # type: ignore
|
||||
"score": document.metadata.get("score", 0.0), # type: ignore
|
||||
}
|
||||
if attachment_info:
|
||||
segment_file_map[segment.id] = [attachment_info]
|
||||
|
|
@ -522,7 +537,7 @@ class RetrievalService:
|
|||
for record in records:
|
||||
if record["segment"].id in segment_child_map:
|
||||
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"]
|
||||
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
|
||||
if record["segment"].id in segment_file_map:
|
||||
record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment]
|
||||
|
||||
|
|
@ -565,6 +580,8 @@ class RetrievalService:
|
|||
flask_app: Flask,
|
||||
retrieval_method: RetrievalMethod,
|
||||
dataset: Dataset,
|
||||
all_documents: list[Document],
|
||||
exceptions: list[str],
|
||||
query: str | None = None,
|
||||
top_k: int = 4,
|
||||
score_threshold: float | None = 0.0,
|
||||
|
|
@ -573,8 +590,6 @@ class RetrievalService:
|
|||
weights: dict | None = None,
|
||||
document_ids_filter: list[str] | None = None,
|
||||
attachment_id: str | None = None,
|
||||
all_documents: list[Document] = [],
|
||||
exceptions: list[str] = [],
|
||||
):
|
||||
if not query and not attachment_id:
|
||||
return
|
||||
|
|
@ -696,3 +711,37 @@ class RetrievalService:
|
|||
}
|
||||
return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
|
||||
attachment_infos = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
attachment_bindings = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
||||
.all()
|
||||
)
|
||||
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
|
||||
|
||||
if attachment_bindings:
|
||||
for upload_file in upload_files:
|
||||
attachment_binding = attachment_binding_map.get(upload_file.id)
|
||||
attachment_info = {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"extension": "." + upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"source_url": sign_upload_file(upload_file.id, upload_file.extension),
|
||||
"size": upload_file.size,
|
||||
}
|
||||
if attachment_binding:
|
||||
attachment_infos.append(
|
||||
{
|
||||
"attachment_id": attachment_binding.attachment_id,
|
||||
"attachment_info": attachment_info,
|
||||
"segment_id": attachment_binding.segment_id,
|
||||
}
|
||||
)
|
||||
return attachment_infos
|
||||
|
|
|
|||
|
|
@ -151,20 +151,14 @@ class DatasetRetrieval:
|
|||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
|
||||
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
|
||||
for dataset in datasets:
|
||||
if dataset.available_document_count == 0 and dataset.provider != "external":
|
||||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
|
||||
if inputs:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
else:
|
||||
|
|
@ -282,26 +276,35 @@ class DatasetRetrieval:
|
|||
)
|
||||
context_files.append(attachment_info)
|
||||
if show_retrieve_source:
|
||||
dataset_ids = [record.segment.dataset_id for record in records]
|
||||
document_ids = [record.segment.document_id for record in records]
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id.in_(document_ids),
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
documents = db.session.execute(dataset_document_stmt).scalars().all() # type: ignore
|
||||
dataset_stmt = select(Dataset).where(
|
||||
Dataset.id.in_(dataset_ids),
|
||||
)
|
||||
datasets = db.session.execute(dataset_stmt).scalars().all() # type: ignore
|
||||
dataset_map = {i.id: i for i in datasets}
|
||||
document_map = {i.id: i for i in documents}
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset and document:
|
||||
dataset_item = dataset_map.get(segment.dataset_id)
|
||||
document_item = document_map.get(segment.document_id)
|
||||
if dataset_item and document_item:
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
dataset_id=dataset_item.id,
|
||||
dataset_name=dataset_item.name,
|
||||
document_id=document_item.id,
|
||||
document_name=document_item.name,
|
||||
data_source_type=document_item.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
score=record.score or 0.0,
|
||||
doc_metadata=document.doc_metadata,
|
||||
doc_metadata=document_item.doc_metadata,
|
||||
)
|
||||
|
||||
if invoke_from.to_source() == "dev":
|
||||
|
|
|
|||
Loading…
Reference in New Issue