refactor: migrate session.query to select API in deal dataset vector index task (#34819)

This commit is contained in:
Renzo 2026-04-09 00:50:59 -05:00 committed by GitHub
parent ee789db443
commit 9a51c2f56a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@ -29,7 +29,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
@ -49,23 +49,24 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
for segment in segments:
@ -82,13 +83,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
elif action == "update":
@ -104,8 +109,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
@ -115,15 +122,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
multimodal_documents = []
@ -172,13 +178,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
else: