From 5821511114e0cfeea0c373375fb840a3f550bba1 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:20:25 -0500 Subject: [PATCH] refactor: migrate session.query to select API in batch clean and disable segments tasks (#34760) --- api/tasks/batch_clean_document_task.py | 18 +++++++---- api/tasks/disable_segments_from_index_task.py | 32 ++++++++----------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 66aafc30b9..56c371fcc1 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -1,9 +1,11 @@ import logging import time +from typing import cast import click from celery import shared_task from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -92,14 +94,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form # ============ Step 3: Delete metadata binding (separate short transaction) ============ try: with session_factory.create_session() as session: - deleted_count = int( - session.query(DatasetMetadataBinding) - .where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id.in_(document_ids), - ) - .delete(synchronize_session=False) + result = cast( + CursorResult, + session.execute( + delete(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id.in_(document_ids), + ) + ), ) + deleted_count = result.rowcount session.commit() logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id) except Exception: diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 3cc267e821..86e96ea3f0 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -27,12 +27,12 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) return - dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1)) if not dataset_document: logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) @@ -58,11 +58,9 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen index_node_ids = [segment.index_node_id for segment in segments] if dataset.is_multimodal: segment_ids = [segment.id for segment in segments] - segment_attachment_bindings = ( - session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() - ) + segment_attachment_bindings = session.scalars( + select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + ).all() if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_node_ids.extend(attachment_ids) @@ -87,16 +85,14 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) except Exception: # update segment error msg - session.query(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ).update( - { - "disabled_at": None, - "disabled_by": None, - "enabled": True, - } + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .values(disabled_at=None, disabled_by=None, enabled=True) ) session.commit() finally: