refactor: migrate session.query to select API in deal dataset index update task (#34847)

This commit is contained in:
Renzo 2026-04-09 09:17:08 -05:00 committed by GitHub
parent e143dbce50
commit 75b88a5416
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task # type: ignore
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
session.query(DatasetDocument)
.where(
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
for segment in segments:
@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
elif action == "update":
dataset_documents = (
session.query(DatasetDocument)
.where(
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
multimodal_documents = []
@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
else: