diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 125f3a8e6b8..a8f341fdd04 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1778,7 +1778,7 @@ class DocumentService: invalid_source_message="Document does not have an uploaded file to download.", missing_file_message="Uploaded file not found.", ) - upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id]) + upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), document.tenant_id, [upload_file_id]) upload_file = upload_files_by_id.get(upload_file_id) if not upload_file: raise NotFound("Uploaded file not found.") @@ -1817,7 +1817,7 @@ class DocumentService: upload_file_ids.append(upload_file_id) upload_file_ids_by_document_id[document_id] = upload_file_id - upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids) + upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), tenant_id, upload_file_ids) missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys()) if missing_upload_file_ids: raise NotFound("Only uploaded-file documents can be downloaded as ZIP.") diff --git a/api/services/file_service.py b/api/services/file_service.py index 1781f0c9727..e41d74ad3eb 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -20,7 +20,6 @@ from constants import ( VIDEO_EXTENSIONS, ) from core.rag.extractor.extract_processor import ExtractProcessor -from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from graphon.file import helpers as file_helpers @@ -268,7 +267,9 @@ class FileService: session.delete(upload_file) @staticmethod - def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]: + def get_upload_files_by_ids( + session: Session, tenant_id: str, upload_file_ids: Sequence[str] + ) -> dict[str, UploadFile]: """ Fetch `UploadFile` rows for a tenant in a single batch query. @@ -282,7 +283,7 @@ class FileService: unique_upload_file_ids: list[str] = list(set(upload_file_id_list)) # Fetch upload files in one query for efficient batch access. - upload_files: Sequence[UploadFile] = db.session.scalars( + upload_files: Sequence[UploadFile] = session.scalars( select(UploadFile).where( UploadFile.tenant_id == tenant_id, UploadFile.id.in_(unique_upload_file_ids), diff --git a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py index 1101d834a0d..5eb84f805aa 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service_zip_and_lookup.py @@ -69,7 +69,7 @@ def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers: Session) -> None: """Ensure empty input returns an empty mapping without hitting the database.""" - assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {} + assert FileService.get_upload_files_by_ids(db_session_with_containers, str(uuid4()), []) == {} def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers: Session) -> None: @@ -78,7 +78,7 @@ def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_contai file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt") file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt") - result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id]) + result = FileService.get_upload_files_by_ids(db_session_with_containers, tenant_id, [file1.id, file1.id, file2.id]) assert set(result.keys()) == {file1.id, file2.id} assert result[file1.id].id == file1.id @@ -92,6 +92,6 @@ def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers: S file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt") _create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt") - result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id]) + result = FileService.get_upload_files_by_ids(db_session_with_containers, tenant_a, [file_a.id]) assert set(result.keys()) == {file_a.id} diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py index 2e6ca7dbb9c..b81fb823949 100644 --- a/api/tests/unit_tests/services/test_file_service.py +++ b/api/tests/unit_tests/services/test_file_service.py @@ -375,19 +375,19 @@ class TestFileService: file_service.delete_file("file_id") # Should return without doing anything - @patch("services.file_service.db") - def test_get_upload_files_by_ids_empty(self, mock_db): - result = FileService.get_upload_files_by_ids("tenant_id", []) + def test_get_upload_files_by_ids_empty(self): + session = MagicMock() + result = FileService.get_upload_files_by_ids(session, "tenant_id", []) assert result == {} - @patch("services.file_service.db") - def test_get_upload_files_by_ids(self, mock_db): + def test_get_upload_files_by_ids(self): upload_file = MagicMock(spec=UploadFile) upload_file.id = "550e8400-e29b-41d4-a716-446655440000" upload_file.tenant_id = "tenant_id" - mock_db.session.scalars().all.return_value = [upload_file] + session = MagicMock() + session.scalars().all.return_value = [upload_file] - result = FileService.get_upload_files_by_ids("tenant_id", ["550e8400-e29b-41d4-a716-446655440000"]) + result = FileService.get_upload_files_by_ids(session, "tenant_id", ["550e8400-e29b-41d4-a716-446655440000"]) assert result["550e8400-e29b-41d4-a716-446655440000"] == upload_file def test_sanitize_zip_entry_name(self):