From e5fd3133f490c5c26fa0822489d15ba7d93330b0 Mon Sep 17 00:00:00 2001 From: bohdansolovie <153934212+bohdansolovie@users.noreply.github.com> Date: Tue, 14 Apr 2026 09:27:39 -0400 Subject: [PATCH] test: migrate task integration tests to SQLAlchemy 2.0 query APIs (#35170) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../tasks/test_add_document_to_index_task.py | 17 ++--- .../tasks/test_batch_clean_document_task.py | 76 ++++++++++++++----- ...test_batch_create_segment_to_index_task.py | 51 ++++++------- .../tasks/test_dataset_indexing_task.py | 3 +- .../tasks/test_document_indexing_sync_task.py | 53 ++++++------- 5 files changed, 117 insertions(+), 83 deletions(-) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 4b04c1accb..fcc15aad42 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Verify logs exist before processing - existing_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + existing_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(existing_logs) == 2 # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify auto disable logs were deleted - remaining_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + remaining_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(remaining_logs) == 0 # Verify index processing occurred normally diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 6cbbe43137..e29ca7ebab 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,6 +11,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType @@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_with_image_files( @@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Verify that the task completed successfully by checking the log output @@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify database cleanup db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( @@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document_id).limit(1) + ) assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( @@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_multiple_documents( @@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( @@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None except Exception as e: @@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( @@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Verify initial state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 3 + ) + assert ( + db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1)) + is not None + ) # Store original IDs for verification document_id = document.id @@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify final database state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id) + ) + == 0 + ) + assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index f9ae33b32f..05827112d4 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask: from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that segments were created - segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(segments) == 3 # Verify segment content and metadata @@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created (since dataset doesn't exist) - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify no documents were modified - documents = db_session_with_containers.query(Document).all() + documents = db_session_with_containers.scalars(select(Document)).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( @@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) db_session_with_containers.refresh(dataset) - segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments_for_dataset = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( @@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id) + ).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( @@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions - all_segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + all_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(all_segments) == 6 # 3 existing + 3 new # Verify position ordering diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 13ea94348a..684097851b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration: def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: """Return the latest persisted document state.""" - return db_session_with_containers.query(Document).where(Document.id == document_id).first() + return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index d4021143ef..d94c1fdf24 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -12,6 +12,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy import delete, func, select, update from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask: """Test that task raises error when data_source_info is empty.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) - db_session_with_containers.query(Document).where(Document.id == context["document"].id).update( - {"data_source_info": None} + db_session_with_containers.execute( + update(Document).where(Document.id == context["document"].id).values(data_source_info=None) ) db_session_with_containers.commit() @@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR @@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.COMPLETED @@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None @@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask: context = self._create_notion_sync_context(db_session_with_containers) def _delete_dataset_before_clean() -> str: - db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete() + db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id)) db_session_with_containers.commit() return "2024-01-02T00:00:00Z" @@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR