refactor: migrate session.query to select API in batch clean and disable segments tasks (#34760)

This commit is contained in:
Renzo 2026-04-08 18:20:25 -05:00 committed by GitHub
parent d6d9b04c41
commit 5821511114
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 25 deletions

View File

@ -1,9 +1,11 @@
import logging
import time
from typing import cast
import click
from celery import shared_task
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -92,14 +94,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = int(
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
.delete(synchronize_session=False)
result = cast(
CursorResult,
session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
),
)
deleted_count = result.rowcount
session.commit()
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
except Exception:

View File

@ -3,7 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -27,12 +27,12 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
start_at = time.perf_counter()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1))
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
@ -58,11 +58,9 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
segment_attachment_bindings = session.scalars(
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
).all()
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_node_ids.extend(attachment_ids)
@ -87,16 +85,14 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:
# update segment error msg
session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
session.execute(
update(DocumentSegment)
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.values(disabled_at=None, disabled_by=None, enabled=True)
)
session.commit()
finally: