From d6d9b04c416b3575fb9468717922fe3580e4c911 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:19:36 -0500 Subject: [PATCH] refactor: migrate session.query to select API in add document and clean document tasks (#34761) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/tasks/add_document_to_index_task.py | 29 ++++++++++++------------- api/tasks/clean_document_task.py | 16 ++++++++------ 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index ae55c9ee03..c9d4673c0a 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import delete, select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -30,7 +31,9 @@ def add_document_to_index_task(dataset_document_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() + dataset_document = session.scalar( + select(DatasetDocument).where(DatasetDocument.id == dataset_document_id).limit(1) + ) if not dataset_document: logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) return @@ -45,15 +48,14 @@ def add_document_to_index_task(dataset_document_id: str): if not dataset: raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.") - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.status == SegmentStatus.COMPLETED, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() documents = [] multimodal_documents = [] @@ -104,18 +106,15 @@ def add_document_to_index_task(dataset_document_id: str): index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) # delete auto disable log - session.query(DatasetAutoDisableLog).where( - DatasetAutoDisableLog.document_id == dataset_document.id - ).delete() + session.execute( + delete(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id) + ) # update segment to enable - session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( - { - DocumentSegment.enabled: True, - DocumentSegment.disabled_at: None, - DocumentSegment.disabled_by: None, - DocumentSegment.updated_at: naive_utc_now(), - } + session.execute( + update(DocumentSegment) + .where(DocumentSegment.document_id == dataset_document.id) + .values(enabled=True, disabled_at=None, disabled_by=None, updated_at=naive_utc_now()) ) session.commit() diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index a017e9114b..a657cd553a 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -32,7 +32,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Document has no dataset") @@ -63,7 +63,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if index_node_ids: index_processor = IndexProcessorFactory(doc_form).init_index_processor() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset: index_processor.clean( dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True @@ -94,7 +94,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session, session.begin(): if file_id: - file = session.query(UploadFile).where(UploadFile.id == file_id).first() + file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if file: try: storage.delete(file.key) @@ -124,10 +124,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session, session.begin(): # delete dataset metadata binding - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id == document_id, - ).delete() + session.execute( + delete(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id == document_id, + ) + ) end_at = time.perf_counter() logger.info(