refactor: migrate session.query to select API in document indexing sync task (#34813)

This commit is contained in:
Renzo 2026-04-09 00:44:13 -05:00 committed by GitHub
parent a76a8876d1
commit f5ea61e93e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 13 deletions

View File

@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
tenant_id = None
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.exception("Failed to clean vector index for document %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if not document:
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
return
@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
try:
indexing_runner = IndexingRunner()
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
indexing_runner.run([document])
end_at = time.perf_counter()
@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
except Exception as e:
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
if document:
document.indexing_status = IndexingStatus.ERROR
document.error = str(e)

View File

@ -80,7 +80,7 @@ def mock_db_session(mock_document, mock_dataset):
with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory:
session = MagicMock()
session.scalars.return_value.all.return_value = []
session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
session.scalar.side_effect = [mock_document, mock_dataset]
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
@ -242,14 +242,13 @@ class TestDataSourceInfoSerialization:
# DB session mock — shared across all ``session_factory.create_session()`` calls
session = MagicMock()
session.scalars.return_value.all.return_value = []
# .where() path: session 1 reads document + dataset, session 2 reads dataset
session.query.return_value.where.return_value.first.side_effect = [
# All .first() calls are now session.scalar() — ordered by call sequence:
# session 1: document + dataset, session 2: dataset (clean), session 3: document (update),
# session 4: document (indexing)
session.scalar.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
]
# .filter_by() path: session 3 (update), session 4 (indexing)
session.query.return_value.filter_by.return_value.first.side_effect = [
mock_document,
mock_document,
]