test: migrate task integration tests to SQLAlchemy 2.0 query APIs (#35170)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
bohdansolovie 2026-04-14 09:27:39 -04:00 committed by GitHub
parent e1bbe57f9c
commit e5fd3133f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 117 additions and 83 deletions

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask:
redis_client.set(indexing_cache_key, "processing", ex=300)
# Verify logs exist before processing
existing_logs = (
db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
)
existing_logs = db_session_with_containers.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id)
).all()
assert len(existing_logs) == 2
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify auto disable logs were deleted
remaining_logs = (
db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
)
remaining_logs = db_session_with_containers.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id)
).all()
assert len(remaining_logs) == 0
# Verify index processing occurred normally

View File

@ -11,6 +11,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit() # Ensure all changes are committed
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_with_image_files(
@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Verify that the task completed successfully by checking the log output
@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
# Verify database cleanup
db_session_with_containers.commit()
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_dataset_not_found(
@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Document should still exist since cleanup failed
existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first()
existing_document = db_session_with_containers.scalar(
select(Document).where(Document.id == document_id).limit(1)
)
assert existing_document is not None
def test_batch_clean_document_task_storage_cleanup_failure(
@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted from database
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted from database
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
def test_batch_clean_document_task_multiple_documents(
@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that all upload files are deleted
for file_id in file_ids:
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
assert deleted_file is None
def test_batch_clean_document_task_different_doc_forms(
@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check that segment is deleted
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
except Exception as e:
@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Check if the segment still exists (task may have failed before deletion)
existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
existing_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
if existing_segment is not None:
# If segment still exists, the task failed before deletion
# This is acceptable in test environments with external service issues
@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that all upload files are deleted
for file_id in file_ids:
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
assert deleted_file is None
def test_batch_clean_document_task_integration_with_real_database(
@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask:
db_session_with_containers.commit()
# Verify initial state
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
)
== 3
)
assert (
db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1))
is not None
)
# Store original IDs for verification
document_id = document.id
@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask:
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.scalar(
select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
)
assert deleted_segment is None
# Check that upload file is deleted
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
assert deleted_file is None
# Verify final database state
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None
assert (
db_session_with_containers.scalar(
select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id)
)
== 0
)
assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None

View File

@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask:
from extensions.ext_redis import redis_client
# Clear all test data
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.execute(delete(DocumentSegment))
db_session_with_containers.execute(delete(Document))
db_session_with_containers.execute(delete(Dataset))
db_session_with_containers.execute(delete(UploadFile))
db_session_with_containers.execute(delete(TenantAccountJoin))
db_session_with_containers.execute(delete(Tenant))
db_session_with_containers.execute(delete(Account))
db_session_with_containers.commit()
# Clear Redis cache
@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results
# Check that segments were created
segments = (
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
)
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position)
).all()
assert len(segments) == 3
# Verify segment content and metadata
@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created (since dataset doesn't exist)
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify no documents were modified
documents = db_session_with_containers.query(Document).all()
documents = db_session_with_containers.scalars(select(Document)).all()
assert len(documents) == 0
def test_batch_create_segment_to_index_task_document_not_found(
@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify dataset remains unchanged (no segments were added to the dataset)
db_session_with_containers.refresh(dataset)
segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments_for_dataset = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id)
).all()
assert len(segments_for_dataset) == 0
def test_batch_create_segment_to_index_task_document_not_available(
@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error"
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all()
segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
assert len(segments) == 0
def test_batch_create_segment_to_index_task_upload_file_not_found(
@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify no segments were created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify document remains unchanged
@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify error handling
# Since exception was raised, no segments should be created
segments = db_session_with_containers.query(DocumentSegment).all()
segments = db_session_with_containers.scalars(select(DocumentSegment)).all()
assert len(segments) == 0
# Verify document remains unchanged
@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results
# Check that new segments were created with correct positions
all_segments = (
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
)
all_segments = db_session_with_containers.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position)
).all()
assert len(all_segments) == 6 # 3 existing + 3 new
# Verify position ordering

View File

@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration:
def _query_document(self, db_session_with_containers, document_id: str) -> Document | None:
"""Return the latest persisted document state."""
return db_session_with_containers.query(Document).where(Document.id == document_id).first()
return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1))
def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None:
"""Assert all target documents are persisted in parsing status."""

View File

@ -12,6 +12,7 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy import delete, func, select, update
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask:
"""Test that task raises error when data_source_info is empty."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
db_session_with_containers.query(Document).where(Document.id == context["document"].id).update(
{"data_source_info": None}
db_session_with_containers.execute(
update(Document).where(Document.id == context["document"].id).values(data_source_info=None)
)
db_session_with_containers.commit()
@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.ERROR
@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.COMPLETED
@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask:
context = self._create_notion_sync_context(db_session_with_containers)
def _delete_dataset_before_clean() -> str:
db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete()
db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id))
db_session_with_containers.commit()
return "2024-01-02T00:00:00Z"
@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
remaining_segments = db_session_with_containers.scalar(
select(func.count())
.select_from(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.PARSING
@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask:
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
updated_document = db_session_with_containers.scalar(
select(Document).where(Document.id == context["document"].id).limit(1)
)
assert updated_document is not None
assert updated_document.indexing_status == IndexingStatus.ERROR