diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index a65069b1b7..635eab73f0 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -592,111 +592,116 @@ class DatasetRetrieval: """Handle retrieval end.""" with flask_app.app_context(): dify_documents = [document for document in documents if document.provider == "dify"] - segment_ids = [] - segment_index_node_ids = [] + if not dify_documents: + self._send_trace_task(message_id, documents, timer) + return + with Session(db.engine) as session: - for document in dify_documents: - if document.metadata is not None: - dataset_document_stmt = select(DatasetDocument).where( - DatasetDocument.id == document.metadata["document_id"] - ) - dataset_document = session.scalar(dataset_document_stmt) - if dataset_document: - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - segment_id = None - if ( - "doc_type" not in document.metadata - or document.metadata.get("doc_type") == DocType.TEXT - ): - child_chunk_stmt = select(ChildChunk).where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - child_chunk = session.scalar(child_chunk_stmt) - if child_chunk: - segment_id = child_chunk.segment_id - elif ( - "doc_type" in document.metadata - and document.metadata.get("doc_type") == DocType.IMAGE - ): - attachment_info_dict = RetrievalService.get_segment_attachment_info( - dataset_document.dataset_id, - dataset_document.tenant_id, - document.metadata.get("doc_id") or "", - session, - ) - if attachment_info_dict: - segment_id = attachment_info_dict["segment_id"] + # Collect all document_ids and batch fetch DatasetDocuments + document_ids = { + doc.metadata["document_id"] + for doc in dify_documents + if doc.metadata and "document_id" in doc.metadata + } + if not document_ids: + self._send_trace_task(message_id, documents, timer) + return + + dataset_docs_stmt = select(DatasetDocument).where(DatasetDocument.id.in_(document_ids)) + dataset_docs = session.scalars(dataset_docs_stmt).all() + dataset_doc_map = {str(doc.id): doc for doc in dataset_docs} + + # Categorize documents by type and collect necessary IDs + parent_child_text_docs: list[tuple[Document, DatasetDocument]] = [] + parent_child_image_docs: list[tuple[Document, DatasetDocument]] = [] + normal_text_docs: list[tuple[Document, DatasetDocument]] = [] + normal_image_docs: list[tuple[Document, DatasetDocument]] = [] + + for doc in dify_documents: + if not doc.metadata or "document_id" not in doc.metadata: + continue + dataset_doc = dataset_doc_map.get(doc.metadata["document_id"]) + if not dataset_doc: + continue + + is_image = doc.metadata.get("doc_type") == DocType.IMAGE + is_parent_child = dataset_doc.doc_form == IndexStructureType.PARENT_CHILD_INDEX + + if is_parent_child: + if is_image: + parent_child_image_docs.append((doc, dataset_doc)) + else: + parent_child_text_docs.append((doc, dataset_doc)) + else: + if is_image: + normal_image_docs.append((doc, dataset_doc)) + else: + normal_text_docs.append((doc, dataset_doc)) + + segment_ids_to_update: set[str] = set() + + # Process PARENT_CHILD_INDEX text documents - batch fetch ChildChunks + if parent_child_text_docs: + index_node_ids = [doc.metadata["doc_id"] for doc, _ in parent_child_text_docs if doc.metadata] + if index_node_ids: + child_chunks_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(index_node_ids)) + child_chunks = session.scalars(child_chunks_stmt).all() + child_chunk_map = {chunk.index_node_id: chunk.segment_id for chunk in child_chunks} + for doc, _ in parent_child_text_docs: + if doc.metadata: + segment_id = child_chunk_map.get(doc.metadata["doc_id"]) if segment_id: - if segment_id not in segment_ids: - segment_ids.append(segment_id) - _ = ( - session.query(DocumentSegment) - .where(DocumentSegment.id == segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, - ) - ) - else: - query = None - if ( - "doc_type" not in document.metadata - or document.metadata.get("doc_type") == DocType.TEXT - ): - if document.metadata["doc_id"] not in segment_index_node_ids: - segment = ( - session.query(DocumentSegment) - .where(DocumentSegment.index_node_id == document.metadata["doc_id"]) - .first() - ) - if segment: - segment_index_node_ids.append(document.metadata["doc_id"]) - segment_ids.append(segment.id) - query = session.query(DocumentSegment).where( - DocumentSegment.id == segment.id - ) - elif ( - "doc_type" in document.metadata - and document.metadata.get("doc_type") == DocType.IMAGE - ): - attachment_info_dict = RetrievalService.get_segment_attachment_info( - dataset_document.dataset_id, - dataset_document.tenant_id, - document.metadata.get("doc_id") or "", - session, - ) - if attachment_info_dict: - segment_id = attachment_info_dict["segment_id"] - if segment_id not in segment_ids: - segment_ids.append(segment_id) - query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id) - if query: - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.where( - DocumentSegment.dataset_id == document.metadata["dataset_id"] - ) + segment_ids_to_update.add(str(segment_id)) - # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, - ) + # Process non-PARENT_CHILD_INDEX text documents - batch fetch DocumentSegments + if normal_text_docs: + index_node_ids = [doc.metadata["doc_id"] for doc, _ in normal_text_docs if doc.metadata] + if index_node_ids: + segments_stmt = select(DocumentSegment).where(DocumentSegment.index_node_id.in_(index_node_ids)) + segments = session.scalars(segments_stmt).all() + segment_map = {seg.index_node_id: seg.id for seg in segments} + for doc, _ in normal_text_docs: + if doc.metadata: + segment_id = segment_map.get(doc.metadata["doc_id"]) + if segment_id: + segment_ids_to_update.add(str(segment_id)) - db.session.commit() + # Process IMAGE documents - batch fetch SegmentAttachmentBindings + all_image_docs = parent_child_image_docs + normal_image_docs + if all_image_docs: + attachment_ids = [ + doc.metadata["doc_id"] + for doc, _ in all_image_docs + if doc.metadata and doc.metadata.get("doc_id") + ] + if attachment_ids: + bindings_stmt = select(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.attachment_id.in_(attachment_ids) + ) + bindings = session.scalars(bindings_stmt).all() + segment_ids_to_update.update(str(binding.segment_id) for binding in bindings) - # get tracing instance - trace_manager: TraceQueueManager | None = ( - self.application_generate_entity.trace_manager if self.application_generate_entity else None - ) - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer + # Batch update hit_count for all segments + if segment_ids_to_update: + session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, ) + session.commit() + + self._send_trace_task(message_id, documents, timer) + + def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None): + """Send trace task if trace manager is available.""" + trace_manager: TraceQueueManager | None = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer ) + ) def _on_query( self,