From c91253d05db4add63d072f74a3ce51b05d2d4f1e Mon Sep 17 00:00:00 2001
From: kenwoodjw
Date: Fri, 12 Sep 2025 15:29:57 +0800
Subject: [PATCH] fix segment deletion race condition (#24408)

Signed-off-by: kenwoodjw
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
---
 .../processor/parent_child_index_processor.py | 36 ++++++++-----
 api/services/dataset_service.py               | 53 ++++++++++++++++---
 api/tasks/delete_segment_from_index_task.py   | 19 +++++--
 3 files changed, 84 insertions(+), 24 deletions(-)

diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py
index d1088af853..514596200f 100644
--- a/api/core/rag/index_processor/processor/parent_child_index_processor.py
+++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py
@@ -113,21 +113,33 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         # node_ids is segment's node_ids
         if dataset.indexing_technique == "high_quality":
             delete_child_chunks = kwargs.get("delete_child_chunks") or False
+            precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
             vector = Vector(dataset)
+
             if node_ids:
-                child_node_ids = (
-                    db.session.query(ChildChunk.index_node_id)
-                    .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
-                    .where(
-                        DocumentSegment.dataset_id == dataset.id,
-                        DocumentSegment.index_node_id.in_(node_ids),
-                        ChildChunk.dataset_id == dataset.id,
+                # Use precomputed child_node_ids if available (to avoid race conditions)
+                if precomputed_child_node_ids is not None:
+                    child_node_ids = precomputed_child_node_ids
+                else:
+                    # Fall back to the original query (may find nothing if segments are already deleted)
+                    child_node_ids = (
+                        db.session.query(ChildChunk.index_node_id)
+                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                        .where(
+                            DocumentSegment.dataset_id == dataset.id,
+                            DocumentSegment.index_node_id.in_(node_ids),
+                            ChildChunk.dataset_id == dataset.id,
+                        )
+                        .all()
                     )
-                    .all()
-                )
-                child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
-                vector.delete_by_ids(child_node_ids)
-                if delete_child_chunks:
+                    child_node_ids = [child_node_id[0] for child_node_id in child_node_ids if child_node_id[0]]
+
+                # Delete from vector index
+                if child_node_ids:
+                    vector.delete_by_ids(child_node_ids)
+
+                # Delete from database
+                if delete_child_chunks and child_node_ids:
                     db.session.query(ChildChunk).where(
                         ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
                     ).delete(synchronize_session=False)
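
Note on the processor change: the fallback lookup resolves child chunk IDs through
the parent DocumentSegment rows, so once those rows are deleted it silently returns
nothing and the child vectors are orphaned. A minimal standalone sketch of the two
paths (illustrative names only, not Dify code; a dict stands in for the
DocumentSegment/ChildChunk join):

    # parent index node id -> child index node ids; stands in for the DB join
    segments = {"seg-1": ["child-a", "child-b"]}

    def child_ids_via_join(node_ids):
        # Mirrors the fallback query: children are resolved THROUGH parent rows.
        return [c for n in node_ids for c in segments.get(n, [])]

    def clean(node_ids, precomputed_child_node_ids=None):
        if precomputed_child_node_ids is not None:
            return precomputed_child_node_ids
        return child_ids_via_join(node_ids)

    precomputed = child_ids_via_join(["seg-1"])  # caller reads children first
    del segments["seg-1"]                        # parent row deleted before the task runs

    assert clean(["seg-1"]) == []                # fallback path: children orphaned
    assert clean(["seg-1"], precomputed) == ["child-a", "child-b"]  # precomputed path survives
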
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 997b28524a..8d1db20b04 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -2365,7 +2365,22 @@ class SegmentService:
         if segment.enabled:
             # send delete segment index task
             redis_client.setex(indexing_cache_key, 600, 1)
-            delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id)
+
+            # Get child chunk IDs before the parent segment is deleted
+            child_node_ids = []
+            if segment.index_node_id:
+                child_chunks = (
+                    db.session.query(ChildChunk.index_node_id)
+                    .where(
+                        ChildChunk.segment_id == segment.id,
+                        ChildChunk.dataset_id == dataset.id,
+                    )
+                    .all()
+                )
+                child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
+
+            delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids)
+
         db.session.delete(segment)
         # update document word count
         assert document.word_count is not None
@@ -2375,9 +2390,13 @@
 
     @classmethod
     def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
-        assert isinstance(current_user, Account)
-        segments = (
-            db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count)
+        assert current_user is not None
+        # Guard against an empty segment_ids list, which would emit a WHERE false condition
+        if not segment_ids:
+            return
+        segments_info = (
+            db.session.query(DocumentSegment)
+            .with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
             .where(
                 DocumentSegment.id.in_(segment_ids),
                 DocumentSegment.dataset_id == dataset.id,
@@ -2387,18 +2406,36 @@
             .all()
         )
 
-        if not segments:
+        if not segments_info:
             return
 
-        index_node_ids = [seg.index_node_id for seg in segments]
-        total_words = sum(seg.word_count for seg in segments)
+        index_node_ids = [info[0] for info in segments_info]
+        segment_db_ids = [info[1] for info in segments_info]
+        total_words = sum(info[2] for info in segments_info if info[2] is not None)
+
+        # Get child chunk IDs before the parent segments are deleted
+        child_node_ids = []
+        if index_node_ids:
+            child_chunks = (
+                db.session.query(ChildChunk.index_node_id)
+                .where(
+                    ChildChunk.segment_id.in_(segment_db_ids),
+                    ChildChunk.dataset_id == dataset.id,
+                )
+                .all()
+            )
+            child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
+
+        # Start async cleanup with both parent and child node IDs
+        if index_node_ids or child_node_ids:
+            delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
 
         document.word_count = (
            document.word_count - total_words if document.word_count and document.word_count > total_words else 0
         )
         db.session.add(document)
 
-        delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
+        # Delete database records
         db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
         db.session.commit()
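
The service-side ordering above carries the fix: child IDs are read while the parent
rows still exist, both ID sets are handed to the async task, and only then are the
rows deleted. A rough sketch of that invariant with stand-in objects (FakeQueue is
hypothetical and mimics only Celery's .delay(); none of this is Dify code):

    from dataclasses import dataclass, field

    @dataclass
    class FakeQueue:
        calls: list = field(default_factory=list)

        def delay(self, *args):
            # Stands in for Celery's task.delay(); just records the enqueue.
            self.calls.append(args)

    def delete_segments(rows, queue):
        index_node_ids = [r["index_node_id"] for r in rows]
        # 1) Capture child IDs while the parent rows are still present.
        child_node_ids = [c for r in rows for c in r["children"]]
        # 2) Enqueue cleanup with BOTH id sets before mutating the database.
        if index_node_ids or child_node_ids:
            queue.delay(index_node_ids, "dataset-id", "document-id", child_node_ids)
        # 3) Only now drop the parent rows (the step that used to race ahead).
        rows.clear()

    queue = FakeQueue()
    delete_segments([{"index_node_id": "seg-1", "children": ["child-a"]}], queue)
    print(queue.calls)  # [(['seg-1'], 'dataset-id', 'document-id', ['child-a'])]
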
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index 0b750cf4db..e8cbd0f250 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -12,7 +12,9 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="dataset")
-def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
+def delete_segment_from_index_task(
+    index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
+):
     """
     Async Remove segment from index
     :param index_node_ids:
     :param dataset_id:
     :param document_id:
@@ -26,6 +28,7 @@
     try:
         dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
+            logger.warning("Dataset %s not found, skipping index cleanup", dataset_id)
             return
 
         dataset_document = db.session.query(Document).where(Document.id == document_id).first()
@@ -33,11 +36,19 @@
         if not dataset_document:
             return
 
         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+            logger.info("Document not in valid state for index operations, skipping")
             return
+        doc_form = dataset_document.doc_form
 
-        index_type = dataset_document.doc_form
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+        # Proceed with index cleanup using the index_node_ids directly
+        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+        index_processor.clean(
+            dataset,
+            index_node_ids,
+            with_keywords=True,
+            delete_child_chunks=True,
+            precomputed_child_node_ids=child_node_ids,
+        )
 
         end_at = time.perf_counter()
         logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
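
Because child_node_ids defaults to None, the widened task signature stays backward
compatible: existing three-argument call sites keep working and simply take the
fallback query path inside clean(). A usage sketch (the import path and IDs are
assumptions for illustration):

    # Assumed import path relative to the api/ package root; IDs are placeholders.
    from tasks.delete_segment_from_index_task import delete_segment_from_index_task

    # Race-safe call: the caller precomputed the child chunk IDs.
    delete_segment_from_index_task.delay(
        ["seg-node-1"], "dataset-id", "document-id", ["child-a", "child-b"]
    )

    # Legacy call: still valid; child_node_ids is None, so the processor falls
    # back to the join query against DocumentSegment.
    delete_segment_from_index_task.delay(["seg-node-1"], "dataset-id", "document-id")
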