refactor: migrate session.query to select API in add document and clean document tasks (#34761)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-04-08 18:19:36 -05:00 committed by GitHub
parent 540289e6c6
commit d6d9b04c41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 22 deletions

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete, select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@@ -30,7 +31,9 @@ def add_document_to_index_task(dataset_document_id: str):
start_at = time.perf_counter()
with session_factory.create_session() as session:
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
dataset_document = session.scalar(
select(DatasetDocument).where(DatasetDocument.id == dataset_document_id).limit(1)
)
if not dataset_document:
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
return
@@ -45,15 +48,14 @@ def add_document_to_index_task(dataset_document_id: str):
if not dataset:
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == SegmentStatus.COMPLETED,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
documents = []
multimodal_documents = []
@@ -104,18 +106,15 @@ def add_document_to_index_task(dataset_document_id: str):
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
session.query(DatasetAutoDisableLog).where(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
session.execute(
delete(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id)
)
# update segment to enable
session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
DocumentSegment.updated_at: naive_utc_now(),
}
session.execute(
update(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id)
.values(enabled=True, disabled_at=None, disabled_by=None, updated_at=naive_utc_now())
)
session.commit()

View File

@@ -32,7 +32,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Document has no dataset")
@@ -63,7 +63,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
@@ -94,7 +94,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
with session_factory.create_session() as session, session.begin():
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if file:
try:
storage.delete(file.key)
@@ -124,10 +124,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
with session_factory.create_session() as session, session.begin():
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.execute(
delete(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
)
)
end_at = time.perf_counter()
logger.info(