refactor: migrate session.query to select API in delete segment and regenerate summary tasks (#34763)

Renzo 2026-04-08 18:19:03 -05:00 committed by GitHub
parent 1d971d3240
commit 540289e6c6
2 changed files with 20 additions and 24 deletions
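
The recurring change in this commit is the move from SQLAlchemy's legacy `Query` API (`session.query(...).first()`) to the 2.0-style `select()` construct. Below is a minimal, self-contained sketch of the single-row pattern, using an in-memory SQLite database and a stand-in `Dataset` model; the mapping here is illustrative, not the project's actual schema.

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    dataset_id = "ds-1"  # illustrative value

    # Legacy 1.x Query API: first() returns the first match or None.
    dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()

    # 2.0-style replacement: build a select() statement and fetch one scalar.
    # scalar() returns the first column of the first row (or None); limit(1)
    # keeps the database from producing more rows than needed.
    dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
```

Both spellings run on SQLAlchemy 1.4 and 2.x; `Query` remains available in 2.x but is considered legacy, which is presumably the motivation for this migration.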


@@ -3,7 +3,7 @@ import time
 
 import click
 from celery import shared_task
-from sqlalchemy import delete
+from sqlalchemy import delete, select
 
 from core.db.session_factory import session_factory
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -29,12 +29,12 @@ def delete_segment_from_index_task(
     start_at = time.perf_counter()
     with session_factory.create_session() as session:
         try:
-            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if not dataset:
                 logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
                 return
 
-            dataset_document = session.query(Document).where(Document.id == document_id).first()
+            dataset_document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             if not dataset_document:
                 return
 
@@ -60,11 +60,9 @@ def delete_segment_from_index_task(
             )
             if dataset.is_multimodal:
                 # delete segment attachment binding
-                segment_attachment_bindings = (
-                    session.query(SegmentAttachmentBinding)
-                    .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
-                    .all()
-                )
+                segment_attachment_bindings = session.scalars(
+                    select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
+                ).all()
                 if segment_attachment_bindings:
                     attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
                     index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
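
For multi-row results, the commit replaces `Query.all()` with `session.scalars(...).all()`. A sketch of why `scalars()` is the right call here, continuing the setup from the first sketch above; the `SegmentAttachmentBinding` mapping and the ids are again stand-ins.

```python
class SegmentAttachmentBinding(Base):
    __tablename__ = "segment_attachment_bindings"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    segment_id: Mapped[str] = mapped_column(String(36))
    attachment_id: Mapped[str] = mapped_column(String(36))


Base.metadata.create_all(engine)

with Session(engine) as session:
    segment_ids = ["seg-1", "seg-2"]  # illustrative values

    # scalars() unwraps the first element of each result row, so .all()
    # yields a list of SegmentAttachmentBinding objects, the same shape
    # Query.all() produced under the legacy API.
    bindings = session.scalars(
        select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
    ).all()
    attachment_ids = [b.attachment_id for b in bindings]
```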
@@ -77,7 +75,7 @@ def delete_segment_from_index_task(
                     session.execute(segment_attachment_bind_delete_stmt)
 
                     # delete upload file
-                    session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
+                    session.execute(delete(UploadFile).where(UploadFile.id.in_(attachment_ids)))
 
             session.commit()
             end_at = time.perf_counter()
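
The upload-file cleanup swaps `Query.delete()` for an explicit `delete()` statement executed on the session. One behavioral nuance worth flagging: the removed call passed `synchronize_session=False`, while an ORM-enabled `session.execute(delete(...))` defaults to the `"auto"` synchronization strategy. If the old behavior matters, it can be pinned via execution options. A sketch, again with a stand-in `UploadFile` model on the same setup:

```python
from sqlalchemy import delete


class UploadFile(Base):
    __tablename__ = "upload_files"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)


Base.metadata.create_all(engine)

with Session(engine) as session:
    attachment_ids = ["file-1", "file-2"]  # illustrative values

    stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))

    # The default strategy is "auto"; this option reproduces the removed
    # synchronize_session=False explicitly.
    session.execute(stmt.execution_options(synchronize_session=False))
    session.commit()
```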

@@ -47,7 +47,7 @@ def regenerate_summary_index_task(
     try:
         with session_factory.create_session() as session:
-            dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+            dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
             if not dataset:
                 logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
                 return
 
@@ -84,8 +84,8 @@ def regenerate_summary_index_task(
             # For embedding_model change: directly query all segments with existing summaries
             # Don't require document indexing_status == "completed"
             # Include summaries with status "completed" or "error" (if they have content)
-            segments_with_summaries = (
-                session.query(DocumentSegment, DocumentSegmentSummary)
+            segments_with_summaries = session.execute(
+                select(DocumentSegment, DocumentSegmentSummary)
                 .join(
                     DocumentSegmentSummary,
                     DocumentSegment.id == DocumentSegmentSummary.chunk_id,
@@ -110,8 +110,7 @@ def regenerate_summary_index_task(
                     DatasetDocument.doc_form != IndexStructureType.QA_INDEX,  # Skip qa_model documents
                 )
                 .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
-                .all()
-            )
+            ).all()
 
             if not segments_with_summaries:
                 logger.info(
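
This query selects two entities per row, which is why the commit routes it through `session.execute()` rather than `session.scalars()`: `scalars()` keeps only the first element of each row and would silently drop the summary. A sketch of the distinction, with stand-in `DocumentSegment` and `DocumentSegmentSummary` mappings on the same setup as above:

```python
class DocumentSegment(Base):
    __tablename__ = "document_segments"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    position: Mapped[int] = mapped_column()


class DocumentSegmentSummary(Base):
    __tablename__ = "document_segment_summaries"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    chunk_id: Mapped[str] = mapped_column(String(36))
    dataset_id: Mapped[str] = mapped_column(String(36))


Base.metadata.create_all(engine)

with Session(engine) as session:
    # execute() returns Row objects; each row carries both mapped instances
    # and can be unpacked positionally.
    rows = session.execute(
        select(DocumentSegment, DocumentSegmentSummary).join(
            DocumentSegmentSummary,
            DocumentSegment.id == DocumentSegmentSummary.chunk_id,
        )
    ).all()

    for segment, summary in rows:
        pass  # both objects are available per row
```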
@@ -215,8 +214,8 @@ def regenerate_summary_index_task(
 
                 try:
                     # Get all segments with existing summaries
-                    segments = (
-                        session.query(DocumentSegment)
+                    segments = session.scalars(
+                        select(DocumentSegment)
                         .join(
                             DocumentSegmentSummary,
                             DocumentSegment.id == DocumentSegmentSummary.chunk_id,
@@ -229,8 +228,7 @@ def regenerate_summary_index_task(
                             DocumentSegmentSummary.dataset_id == dataset_id,
                         )
                         .order_by(DocumentSegment.position.asc())
-                        .all()
-                    )
+                    ).all()
 
                     if not segments:
                         continue
@@ -245,13 +243,13 @@ def regenerate_summary_index_task(
                         summary_record = None
                         try:
                             # Get existing summary record
-                            summary_record = (
-                                session.query(DocumentSegmentSummary)
-                                .filter_by(
-                                    chunk_id=segment.id,
-                                    dataset_id=dataset_id,
+                            summary_record = session.scalar(
+                                select(DocumentSegmentSummary)
+                                .where(
+                                    DocumentSegmentSummary.chunk_id == segment.id,
+                                    DocumentSegmentSummary.dataset_id == dataset_id,
                                 )
-                                .first()
+                                .limit(1)
                             )
 
                             if not summary_record:
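
The last hunk also rewrites `filter_by()` keyword criteria as explicit `where()` comparisons. `select()` supports `filter_by()` too, so this part is a stylistic normalization rather than a requirement of the new API; continuing the stand-in models above, both spellings below are equivalent (the ids are illustrative):

```python
with Session(engine) as session:
    # Explicit column comparisons, as the commit writes it:
    summary_record = session.scalar(
        select(DocumentSegmentSummary)
        .where(
            DocumentSegmentSummary.chunk_id == "seg-1",
            DocumentSegmentSummary.dataset_id == "ds-1",
        )
        .limit(1)
    )

    # Keyword criteria against the primary entity, also valid on select():
    summary_record = session.scalar(
        select(DocumentSegmentSummary).filter_by(chunk_id="seg-1", dataset_id="ds-1").limit(1)
    )
```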