refactor: accept db.session explicitly in FileService.get_upload_files_by_ids (#37695)

This commit is contained in:
Rohit Gahlawat 2026-06-21 10:48:28 +05:30 committed by GitHub
parent 9b4dd9d4e8
commit a8e3257f43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 16 additions and 15 deletions

View File

@ -1778,7 +1778,7 @@ class DocumentService:
invalid_source_message="Document does not have an uploaded file to download.",
missing_file_message="Uploaded file not found.",
)
upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id])
upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), document.tenant_id, [upload_file_id])
upload_file = upload_files_by_id.get(upload_file_id)
if not upload_file:
raise NotFound("Uploaded file not found.")
@ -1817,7 +1817,7 @@ class DocumentService:
upload_file_ids.append(upload_file_id)
upload_file_ids_by_document_id[document_id] = upload_file_id
upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)
upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), tenant_id, upload_file_ids)
missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys())
if missing_upload_file_ids:
raise NotFound("Only uploaded-file documents can be downloaded as ZIP.")

View File

@ -20,7 +20,6 @@ from constants import (
VIDEO_EXTENSIONS,
)
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from extensions.storage.storage_type import StorageType
from graphon.file import helpers as file_helpers
@ -268,7 +267,9 @@ class FileService:
session.delete(upload_file)
@staticmethod
def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
def get_upload_files_by_ids(
session: Session, tenant_id: str, upload_file_ids: Sequence[str]
) -> dict[str, UploadFile]:
"""
Fetch `UploadFile` rows for a tenant in a single batch query.
@ -282,7 +283,7 @@ class FileService:
unique_upload_file_ids: list[str] = list(set(upload_file_id_list))
# Fetch upload files in one query for efficient batch access.
upload_files: Sequence[UploadFile] = db.session.scalars(
upload_files: Sequence[UploadFile] = session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == tenant_id,
UploadFile.id.in_(unique_upload_file_ids),

View File

@ -69,7 +69,7 @@ def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch
def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers: Session) -> None:
"""Ensure empty input returns an empty mapping without hitting the database."""
assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {}
assert FileService.get_upload_files_by_ids(db_session_with_containers, str(uuid4()), []) == {}
def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers: Session) -> None:
@ -78,7 +78,7 @@ def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_contai
file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt")
file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt")
result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id])
result = FileService.get_upload_files_by_ids(db_session_with_containers, tenant_id, [file1.id, file1.id, file2.id])
assert set(result.keys()) == {file1.id, file2.id}
assert result[file1.id].id == file1.id
@ -92,6 +92,6 @@ def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers: S
file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt")
_create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt")
result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id])
result = FileService.get_upload_files_by_ids(db_session_with_containers, tenant_a, [file_a.id])
assert set(result.keys()) == {file_a.id}

View File

@ -375,19 +375,19 @@ class TestFileService:
file_service.delete_file("file_id")
# Should return without doing anything
@patch("services.file_service.db")
def test_get_upload_files_by_ids_empty(self, mock_db):
result = FileService.get_upload_files_by_ids("tenant_id", [])
def test_get_upload_files_by_ids_empty(self):
session = MagicMock()
result = FileService.get_upload_files_by_ids(session, "tenant_id", [])
assert result == {}
@patch("services.file_service.db")
def test_get_upload_files_by_ids(self, mock_db):
def test_get_upload_files_by_ids(self):
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "550e8400-e29b-41d4-a716-446655440000"
upload_file.tenant_id = "tenant_id"
mock_db.session.scalars().all.return_value = [upload_file]
session = MagicMock()
session.scalars().all.return_value = [upload_file]
result = FileService.get_upload_files_by_ids("tenant_id", ["550e8400-e29b-41d4-a716-446655440000"])
result = FileService.get_upload_files_by_ids(session, "tenant_id", ["550e8400-e29b-41d4-a716-446655440000"])
assert result["550e8400-e29b-41d4-a716-446655440000"] == upload_file
def test_sanitize_zip_entry_name(self):