refactor: migrate session.query to select API in delete segment and regenerate summary tasks (#34763)

Renzo 2026-04-08 18:19:03 -05:00 committed by GitHub
parent 1d971d3240
commit 540289e6c6
2 changed files with 20 additions and 24 deletions
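
Every change in this commit follows the same SQLAlchemy 1.x-to-2.0 migration: build a select() statement and hand it to the Session, instead of going through the legacy session.query() API. A minimal sketch of the single-row pattern, with Dataset standing in for any mapped class (session and dataset_id assumed in scope):

    from sqlalchemy import select

    # Legacy Query API: .first() fetches the first matching row or None.
    dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()

    # 2.0-style equivalent: Session.scalar() executes the statement and
    # returns the first column of the first row, or None. The explicit
    # .limit(1) reproduces the LIMIT clause that .first() used to emit.
    dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))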

File 1: delete_segment_from_index_task

@@ -3,7 +3,7 @@ import time
 import click
 from celery import shared_task
-from sqlalchemy import delete
+from sqlalchemy import delete, select
 from core.db.session_factory import session_factory
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -29,12 +29,12 @@ def delete_segment_from_index_task(
     start_at = time.perf_counter()
     with session_factory.create_session() as session:
         try:
-            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if not dataset:
                 logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
                 return
-            dataset_document = session.query(Document).where(Document.id == document_id).first()
+            dataset_document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             if not dataset_document:
                 return
@@ -60,11 +60,9 @@ def delete_segment_from_index_task(
             )
             if dataset.is_multimodal:
                 # delete segment attachment binding
-                segment_attachment_bindings = (
-                    session.query(SegmentAttachmentBinding)
-                    .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
-                    .all()
-                )
+                segment_attachment_bindings = session.scalars(
+                    select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
+                ).all()
                 if segment_attachment_bindings:
                     attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
                     index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
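
For multi-row fetches of a single entity, the hunk above switches to Session.scalars(). A short sketch of why that method, and not plain execute(), is the drop-in replacement (names taken from the diff):

    from sqlalchemy import select

    stmt = select(SegmentAttachmentBinding).where(
        SegmentAttachmentBinding.segment_id.in_(segment_ids)
    )

    # execute() yields Row objects (1-tuples here); scalars() unwraps each
    # Row into the mapped object, so .all() returns a list of bindings just
    # like the old Query.all() did.
    bindings = session.scalars(stmt).all()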
@@ -77,7 +75,7 @@ def delete_segment_from_index_task(
                     session.execute(segment_attachment_bind_delete_stmt)
                     # delete upload file
-                    session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
+                    session.execute(delete(UploadFile).where(UploadFile.id.in_(attachment_ids)))
             session.commit()
             end_at = time.perf_counter()
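
The last hunk in this file replaces the legacy bulk delete with an ORM-enabled delete() statement. One behavioral nuance worth flagging: the old call passed synchronize_session=False, while session.execute(delete(...)) defaults to synchronize_session="auto". A sketch of how the old setting could be kept explicitly, if that distinction ever matters here:

    from sqlalchemy import delete

    # 2.0-style bulk delete; execution_options carries the old
    # synchronize_session=False behavior over explicitly (assumption: the
    # session holds no stale UploadFile instances, so "auto" is also fine).
    session.execute(
        delete(UploadFile).where(UploadFile.id.in_(attachment_ids)),
        execution_options={"synchronize_session": False},
    )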

File 2: regenerate_summary_index_task

@@ -47,7 +47,7 @@ def regenerate_summary_index_task(
     try:
         with session_factory.create_session() as session:
-            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if not dataset:
                 logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
                 return
@@ -84,8 +84,8 @@ def regenerate_summary_index_task(
             # For embedding_model change: directly query all segments with existing summaries
             # Don't require document indexing_status == "completed"
             # Include summaries with status "completed" or "error" (if they have content)
-            segments_with_summaries = (
-                session.query(DocumentSegment, DocumentSegmentSummary)
+            segments_with_summaries = session.execute(
+                select(DocumentSegment, DocumentSegmentSummary)
                 .join(
                     DocumentSegmentSummary,
                     DocumentSegment.id == DocumentSegmentSummary.chunk_id,
@@ -110,8 +110,7 @@ def regenerate_summary_index_task(
                     DatasetDocument.doc_form != IndexStructureType.QA_INDEX,  # Skip qa_model documents
                 )
                 .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
-                .all()
-            )
+            ).all()
             if not segments_with_summaries:
                 logger.info(
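
This query selects two mapped classes at once, which is why it migrates to session.execute(...) rather than session.scalars(...): scalars() would keep only the first entity of each row and silently drop the summaries. A sketch of consuming the result (names as in the diff):

    from sqlalchemy import select

    stmt = select(DocumentSegment, DocumentSegmentSummary).join(
        DocumentSegmentSummary,
        DocumentSegment.id == DocumentSegmentSummary.chunk_id,
    )

    # execute() returns Row objects; each Row unpacks as a
    # (segment, summary) pair.
    for segment, summary in session.execute(stmt):
        ...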
@@ -215,8 +214,8 @@ def regenerate_summary_index_task(
                 try:
                     # Get all segments with existing summaries
-                    segments = (
-                        session.query(DocumentSegment)
+                    segments = session.scalars(
+                        select(DocumentSegment)
                         .join(
                             DocumentSegmentSummary,
                             DocumentSegment.id == DocumentSegmentSummary.chunk_id,
@@ -229,8 +228,7 @@ def regenerate_summary_index_task(
                             DocumentSegmentSummary.dataset_id == dataset_id,
                         )
                         .order_by(DocumentSegment.position.asc())
-                        .all()
-                    )
+                    ).all()
                     if not segments:
                         continue
@@ -245,13 +243,13 @@ def regenerate_summary_index_task(
                         summary_record = None
                         try:
                             # Get existing summary record
-                            summary_record = (
-                                session.query(DocumentSegmentSummary)
-                                .filter_by(
-                                    chunk_id=segment.id,
-                                    dataset_id=dataset_id,
-                                )
-                                .first()
-                            )
+                            summary_record = session.scalar(
+                                select(DocumentSegmentSummary)
+                                .where(
+                                    DocumentSegmentSummary.chunk_id == segment.id,
+                                    DocumentSegmentSummary.dataset_id == dataset_id,
+                                )
+                                .limit(1)
+                            )
                             if not summary_record:
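
The final hunk also converts filter_by() keyword shorthand into explicit where() criteria. Both exist on select() in 2.0-style, so this part is stylistic; a side-by-side sketch (illustrative names):

    from sqlalchemy import select

    # filter_by(): keyword arguments matched against attributes of the
    # statement's primary entity.
    stmt = select(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset_id)

    # where(): explicit boolean expressions; works equally for joined
    # entities, non-equality operators, and composite conditions.
    stmt = select(DocumentSegmentSummary).where(
        DocumentSegmentSummary.chunk_id == segment.id,
        DocumentSegmentSummary.dataset_id == dataset_id,
    )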