diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index 97052717db..0f19ecadc8 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -90,13 +90,17 @@ class Jieba(BaseKeyword):
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)

         documents = []
+
+        segment_query_stmt = db.session.query(DocumentSegment).where(
+            DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id.in_(sorted_chunk_indices)
+        )
+        if document_ids_filter:
+            segment_query_stmt = segment_query_stmt.where(DocumentSegment.document_id.in_(document_ids_filter))
+
+        segments = db.session.execute(segment_query_stmt).scalars().all()
+        segment_map = {segment.index_node_id: segment for segment in segments}
         for chunk_index in sorted_chunk_indices:
-            segment_query = db.session.query(DocumentSegment).where(
-                DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
-            )
-            if document_ids_filter:
-                segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
-            segment = segment_query.first()
+            segment = segment_map.get(chunk_index)

             if segment:
                 documents.append(
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index a139fba4d0..9807cb4e6a 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -7,6 +7,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session, load_only

 from configs import dify_config
+from core.db.session_factory import session_factory
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@@ -138,37 +139,47 @@ class RetrievalService:

     @classmethod
     def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]:
-        """Deduplicate documents based on doc_id to avoid duplicate chunks in hybrid search."""
+        """Deduplicate documents in O(n) while preserving first-seen order.
+
+        Rules:
+        - For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
+          metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
+        - For non-dify documents (or dify without doc_id): deduplicate by content key
+          (provider, page_content), keeping the first occurrence.
+        """
         if not documents:
             return documents

-        unique_documents = []
-        seen_doc_ids = set()
+        # Map of dedup key -> chosen Document
+        chosen: dict[tuple, Document] = {}
+        # Preserve the order of first appearance of each dedup key
+        order: list[tuple] = []

-        for document in documents:
-            # For dify provider documents, use doc_id for deduplication
-            if document.provider == "dify" and document.metadata is not None and "doc_id" in document.metadata:
-                doc_id = document.metadata["doc_id"]
-                if doc_id not in seen_doc_ids:
-                    seen_doc_ids.add(doc_id)
-                    unique_documents.append(document)
-                # If duplicate, keep the one with higher score
-                elif "score" in document.metadata:
-                    # Find existing document with same doc_id and compare scores
-                    for i, existing_doc in enumerate(unique_documents):
-                        if (
-                            existing_doc.metadata
-                            and existing_doc.metadata.get("doc_id") == doc_id
-                            and existing_doc.metadata.get("score", 0) < document.metadata.get("score", 0)
-                        ):
-                            unique_documents[i] = document
-                            break
+        for doc in documents:
+            is_dify = doc.provider == "dify"
+            doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
+
+            if is_dify and doc_id:
+                key = ("dify", doc_id)
+                if key not in chosen:
+                    chosen[key] = doc
+                    order.append(key)
+                else:
+                    # Only replace if the new one has a score and it's strictly higher
+                    if "score" in doc.metadata:
+                        new_score = float(doc.metadata.get("score", 0.0))
+                        old_score = float(chosen[key].metadata.get("score", 0.0)) if chosen[key].metadata else 0.0
+                        if new_score > old_score:
+                            chosen[key] = doc
             else:
-                # For non-dify documents, use content-based deduplication
-                if document not in unique_documents:
-                    unique_documents.append(document)
+                # Content-based dedup for non-dify or dify without doc_id
+                content_key = (doc.provider or "dify", doc.page_content)
+                if content_key not in chosen:
+                    chosen[content_key] = doc
+                    order.append(content_key)
+                # If duplicate content appears, we keep the first occurrence (no score comparison)

-        return unique_documents
+        return [chosen[k] for k in order]

     @classmethod
     def _get_dataset(cls, dataset_id: str) -> Dataset | None:
@@ -371,58 +382,96 @@ class RetrievalService:
         include_segment_ids = set()
         segment_child_map = {}
         segment_file_map = {}
-        with Session(bind=db.engine, expire_on_commit=False) as session:
-            # Process documents
-            for document in documents:
-                segment_id = None
-                attachment_info = None
-                child_chunk = None
-                document_id = document.metadata.get("document_id")
-                if document_id not in dataset_documents:
-                    continue
-                dataset_document = dataset_documents[document_id]
-                if not dataset_document:
-                    continue
+        valid_dataset_documents = {}
+        image_doc_ids = []
+        child_index_node_ids = []
+        index_node_ids = []
+        doc_to_document_map = {}
+        for document in documents:
+            document_id = document.metadata.get("document_id")
+            if document_id not in dataset_documents:
+                continue

-                if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
-                    # Handle parent-child documents
-                    if document.metadata.get("doc_type") == DocType.IMAGE:
-                        attachment_info_dict = cls.get_segment_attachment_info(
-                            dataset_document.dataset_id,
-                            dataset_document.tenant_id,
-                            document.metadata.get("doc_id") or "",
-                            session,
-                        )
-                        if attachment_info_dict:
-                            attachment_info = attachment_info_dict["attachment_info"]
-                            segment_id = attachment_info_dict["segment_id"]
-                    else:
-                        child_index_node_id = document.metadata.get("doc_id")
-                        child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
-                        child_chunk = session.scalar(child_chunk_stmt)
+            dataset_document = dataset_documents[document_id]
+            if not dataset_document:
+                continue
+            valid_dataset_documents[document_id] = dataset_document

-                        if not child_chunk:
-                            continue
-                        segment_id = child_chunk.segment_id
+            if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+                doc_id = document.metadata.get("doc_id") or ""
+                doc_to_document_map[doc_id] = document
+                if document.metadata.get("doc_type") == DocType.IMAGE:
+                    image_doc_ids.append(doc_id)
+                else:
+                    child_index_node_ids.append(doc_id)
+            else:
+                doc_id = document.metadata.get("doc_id") or ""
+                doc_to_document_map[doc_id] = document
+                if document.metadata.get("doc_type") == DocType.IMAGE:
+                    image_doc_ids.append(doc_id)
+                else:
+                    index_node_ids.append(doc_id)

-                    if not segment_id:
-                        continue
+        image_doc_ids = [i for i in image_doc_ids if i]
+        child_index_node_ids = [i for i in child_index_node_ids if i]
+        index_node_ids = [i for i in index_node_ids if i]

-                    segment = (
-                        session.query(DocumentSegment)
-                        .where(
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                            DocumentSegment.id == segment_id,
-                        )
-                        .first()
-                    )
+        segment_ids = []
+        index_node_segments: list[DocumentSegment] = []
+        segments: list[DocumentSegment] = []
+        attachment_map = {}
+        child_chunk_map = {}
+        doc_segment_map = {}

-                    if not segment:
-                        continue
+        with session_factory.create_session() as session:
+            attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
+            for attachment in attachments:
+                segment_ids.append(attachment["segment_id"])
+                attachment_map[attachment["segment_id"]] = attachment
+                doc_segment_map[attachment["segment_id"]] = attachment["attachment_id"]
+
+            child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
+            child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
+
+            for i in child_index_nodes:
+                segment_ids.append(i.segment_id)
+                child_chunk_map[i.segment_id] = i
+                doc_segment_map[i.segment_id] = i.index_node_id
+
+            if index_node_ids:
+                document_segment_stmt = select(DocumentSegment).where(
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.index_node_id.in_(index_node_ids),
+                )
+                index_node_segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
+                for index_node_segment in index_node_segments:
+                    doc_segment_map[index_node_segment.id] = index_node_segment.index_node_id
+            if segment_ids:
+                document_segment_stmt = select(DocumentSegment).where(
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.id.in_(segment_ids),
+                )
+                segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
+
+        if index_node_segments:
+            segments.extend(index_node_segments)
+
+        for segment in segments:
+            doc_id = doc_segment_map.get(segment.id)
+            child_chunk = child_chunk_map.get(segment.id)
+            attachment_info = attachment_map.get(segment.id)
+
+            if doc_id:
+                document = doc_to_document_map[doc_id]
+            ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(
+                document.metadata.get("document_id")
+            )
+
+            if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
                 if segment.id not in include_segment_ids:
                     include_segment_ids.add(segment.id)
                     if child_chunk:
@@ -430,10 +479,10 @@ class RetrievalService:
                             "id": child_chunk.id,
                             "content": child_chunk.content,
                             "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
+                            "score": document.metadata.get("score", 0.0) if document else 0.0,
                         }
                         map_detail = {
-                            "max_score": document.metadata.get("score", 0.0),
+                            "max_score": document.metadata.get("score", 0.0) if document else 0.0,
                             "child_chunks": [child_chunk_detail],
                         }
                         segment_child_map[segment.id] = map_detail
@@ -452,13 +501,14 @@ class RetrievalService:
                             "score": document.metadata.get("score", 0.0),
                         }
                         if segment.id in segment_child_map:
-                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)  # type: ignore
                             segment_child_map[segment.id]["max_score"] = max(
-                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                                segment_child_map[segment.id]["max_score"],
+                                document.metadata.get("score", 0.0) if document else 0.0,
                             )
                         else:
                             segment_child_map[segment.id] = {
-                                "max_score": document.metadata.get("score", 0.0),
+                                "max_score": document.metadata.get("score", 0.0) if document else 0.0,
                                 "child_chunks": [child_chunk_detail],
                             }
                         if attachment_info:
@@ -467,46 +517,11 @@ class RetrievalService:
                             else:
                                 segment_file_map[segment.id] = [attachment_info]
             else:
-                # Handle normal documents
-                segment = None
-                if document.metadata.get("doc_type") == DocType.IMAGE:
-                    attachment_info_dict = cls.get_segment_attachment_info(
-                        dataset_document.dataset_id,
-                        dataset_document.tenant_id,
-                        document.metadata.get("doc_id") or "",
-                        session,
-                    )
-                    if attachment_info_dict:
-                        attachment_info = attachment_info_dict["attachment_info"]
-                        segment_id = attachment_info_dict["segment_id"]
-                        document_segment_stmt = select(DocumentSegment).where(
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                            DocumentSegment.id == segment_id,
-                        )
-                        segment = session.scalar(document_segment_stmt)
-                        if segment:
-                            segment_file_map[segment.id] = [attachment_info]
-                else:
-                    index_node_id = document.metadata.get("doc_id")
-                    if not index_node_id:
-                        continue
-                    document_segment_stmt = select(DocumentSegment).where(
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                        DocumentSegment.index_node_id == index_node_id,
-                    )
-                    segment = session.scalar(document_segment_stmt)
-
-                    if not segment:
-                        continue
                 if segment.id not in include_segment_ids:
                     include_segment_ids.add(segment.id)
                     record = {
                         "segment": segment,
-                        "score": document.metadata.get("score"),  # type: ignore
+                        "score": document.metadata.get("score", 0.0),  # type: ignore
                     }
                     if attachment_info:
                         segment_file_map[segment.id] = [attachment_info]
@@ -522,7 +537,7 @@ class RetrievalService:
         for record in records:
             if record["segment"].id in segment_child_map:
                 record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
-                record["score"] = segment_child_map[record["segment"].id]["max_score"]
+                record["score"] = segment_child_map[record["segment"].id]["max_score"]  # type: ignore
             if record["segment"].id in segment_file_map:
                 record["files"] = segment_file_map[record["segment"].id]  # type: ignore[assignment]

@@ -565,6 +580,8 @@ class RetrievalService:
         flask_app: Flask,
         retrieval_method: RetrievalMethod,
         dataset: Dataset,
+        all_documents: list[Document],
+        exceptions: list[str],
         query: str | None = None,
         top_k: int = 4,
         score_threshold: float | None = 0.0,
@@ -573,8 +590,6 @@ class RetrievalService:
         reranking_model: dict | None = None,
         weights: dict | None = None,
         document_ids_filter: list[str] | None = None,
         attachment_id: str | None = None,
-        all_documents: list[Document] = [],
-        exceptions: list[str] = [],
     ):
         if not query and not attachment_id:
             return
@@ -696,3 +711,37 @@ class RetrievalService:
                 }
                 return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id}
         return None
+
+    @classmethod
+    def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
+        attachment_infos = []
+        upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
+        if upload_files:
+            upload_file_ids = [upload_file.id for upload_file in upload_files]
+            attachment_bindings = (
+                session.query(SegmentAttachmentBinding)
+                .where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
+                .all()
+            )
+            attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
+
+            if attachment_bindings:
+                for upload_file in upload_files:
+                    attachment_binding = attachment_binding_map.get(upload_file.id)
+                    attachment_info = {
+                        "id": upload_file.id,
+                        "name": upload_file.name,
+                        "extension": "." + upload_file.extension,
+                        "mime_type": upload_file.mime_type,
+                        "source_url": sign_upload_file(upload_file.id, upload_file.extension),
+                        "size": upload_file.size,
+                    }
+                    if attachment_binding:
+                        attachment_infos.append(
+                            {
+                                "attachment_id": attachment_binding.attachment_id,
+                                "attachment_info": attachment_info,
+                                "segment_id": attachment_binding.segment_id,
+                            }
+                        )
+        return attachment_infos
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 635eab73f0..baf879df95 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -151,20 +151,14 @@ class DatasetRetrieval:
         if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
             planning_strategy = PlanningStrategy.ROUTER
         available_datasets = []
-        for dataset_id in dataset_ids:
-            # get dataset from dataset id
-            dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
-            dataset = db.session.scalar(dataset_stmt)
-            # pass if dataset is not available
-            if not dataset:
+        dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
+        datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all()  # type: ignore
+        for dataset in datasets:
+            if dataset.available_document_count == 0 and dataset.provider != "external":
                 continue
-
-            # pass if dataset is not available
-            if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
-                continue
             available_datasets.append(dataset)
+
         if inputs:
             inputs = {key: str(value) for key, value in inputs.items()}
         else:
@@ -282,26 +276,35 @@ class DatasetRetrieval:
                     )
                     context_files.append(attachment_info)
         if show_retrieve_source:
+            dataset_ids = [record.segment.dataset_id for record in records]
+            document_ids = [record.segment.document_id for record in records]
+            dataset_document_stmt = select(DatasetDocument).where(
+                DatasetDocument.id.in_(document_ids),
+                DatasetDocument.enabled == True,
+                DatasetDocument.archived == False,
+            )
+            documents = db.session.execute(dataset_document_stmt).scalars().all()  # type: ignore
+            dataset_stmt = select(Dataset).where(
+                Dataset.id.in_(dataset_ids),
+            )
+            datasets = db.session.execute(dataset_stmt).scalars().all()  # type: ignore
+            dataset_map = {i.id: i for i in datasets}
+            document_map = {i.id: i for i in documents}
             for record in records:
                 segment = record.segment
-                dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
-                dataset_document_stmt = select(DatasetDocument).where(
-                    DatasetDocument.id == segment.document_id,
-                    DatasetDocument.enabled == True,
-                    DatasetDocument.archived == False,
-                )
-                document = db.session.scalar(dataset_document_stmt)
-                if dataset and document:
+                dataset_item = dataset_map.get(segment.dataset_id)
+                document_item = document_map.get(segment.document_id)
+                if dataset_item and document_item:
                     source = RetrievalSourceMetadata(
-                        dataset_id=dataset.id,
-                        dataset_name=dataset.name,
-                        document_id=document.id,
-                        document_name=document.name,
-                        data_source_type=document.data_source_type,
+                        dataset_id=dataset_item.id,
+                        dataset_name=dataset_item.name,
+                        document_id=document_item.id,
+                        document_name=document_item.name,
+                        data_source_type=document_item.data_source_type,
                         segment_id=segment.id,
                         retriever_from=invoke_from.to_source(),
                         score=record.score or 0.0,
-                        doc_metadata=document.doc_metadata,
+                        doc_metadata=document_item.doc_metadata,
                     )
                     if invoke_from.to_source() == "dev":
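
Review notes (appended commentary; not part of the patch):

1. The rewritten `_deduplicate_documents` replaces the quadratic rescan of `unique_documents` with a single pass over a dict keyed by `("dify", doc_id)` or `(provider, page_content)`. A sanity-check sketch of the intended contract, assuming `Document` from `core.rag.models.document` accepts `page_content`, `provider`, and `metadata` as keyword arguments:

```python
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document

# Hypothetical inputs; only provider, metadata, and page_content are exercised.
docs = [
    Document(page_content="a", provider="dify", metadata={"doc_id": "d1", "score": 0.4}),
    Document(page_content="a", provider="dify", metadata={"doc_id": "d1", "score": 0.9}),  # outscores 0.4
    Document(page_content="a", provider="dify", metadata={"doc_id": "d1"}),  # scoreless duplicate: ignored
    Document(page_content="x", provider="external", metadata={}),
    Document(page_content="x", provider="external", metadata={}),  # duplicate content: first kept
]

deduped = RetrievalService._deduplicate_documents(docs)
assert [d.page_content for d in deduped] == ["a", "x"]  # first-seen order preserved
assert deduped[0].metadata["score"] == 0.9  # highest-scoring duplicate wins for d1
```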
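
2. The batching pattern is the same in all three files: collect ids up front, issue one `IN (...)` query, then resolve each id through a dict instead of querying inside the loop. A self-contained sketch of that shape (the `Segment` model and node ids are illustrative, not from this codebase):

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Segment(Base):
    __tablename__ = "segments"
    id: Mapped[int] = mapped_column(primary_key=True)
    node_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Segment(node_id=f"n{i}") for i in range(3)])
    session.commit()

    node_ids = ["n2", "n0", "missing"]
    # Before: one SELECT per id inside the loop (the N+1 shape being removed).
    # After: a single SELECT ... WHERE node_id IN (...), then O(1) dict lookups.
    rows = session.execute(select(Segment).where(Segment.node_id.in_(node_ids))).scalars().all()
    by_node = {s.node_id: s for s in rows}
    # Reassemble in the caller's order, skipping ids with no match, as jieba.py does above.
    ordered = [by_node[i] for i in node_ids if i in by_node]
    assert [s.node_id for s in ordered] == ["n2", "n0"]
```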
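
3. Dropping the `all_documents: list[Document] = []` and `exceptions: list[str] = []` defaults in favor of required parameters also retires a mutable-default footgun: a list default is created once at `def` time and shared by every call that omits the argument, so results could leak between retrievals if any call site relied on the default. A minimal illustration of the hazard:

```python
def broken(item, acc=[]):  # default list is evaluated once, at function definition
    acc.append(item)
    return acc

assert broken(1) == [1]
assert broken(2) == [1, 2]  # state leaked from the previous call


def fixed(item, acc=None):  # the usual remedy; the patch instead makes callers pass the list
    acc = [] if acc is None else acc
    acc.append(item)
    return acc

assert fixed(1) == [1]
assert fixed(2) == [2]
```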