mirror of
https://github.com/langgenius/dify.git
synced 2026-06-22 19:21:13 +08:00
refactor: accept db.session explicitly in FileService.get_upload_files_by_ids (#37695)
This commit is contained in:
parent
9b4dd9d4e8
commit
a8e3257f43
@ -1778,7 +1778,7 @@ class DocumentService:
|
||||
invalid_source_message="Document does not have an uploaded file to download.",
|
||||
missing_file_message="Uploaded file not found.",
|
||||
)
|
||||
upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id])
|
||||
upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), document.tenant_id, [upload_file_id])
|
||||
upload_file = upload_files_by_id.get(upload_file_id)
|
||||
if not upload_file:
|
||||
raise NotFound("Uploaded file not found.")
|
||||
@ -1817,7 +1817,7 @@ class DocumentService:
|
||||
upload_file_ids.append(upload_file_id)
|
||||
upload_file_ids_by_document_id[document_id] = upload_file_id
|
||||
|
||||
upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)
|
||||
upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), tenant_id, upload_file_ids)
|
||||
missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys())
|
||||
if missing_upload_file_ids:
|
||||
raise NotFound("Only uploaded-file documents can be downloaded as ZIP.")
|
||||
|
||||
@ -20,7 +20,6 @@ from constants import (
|
||||
VIDEO_EXTENSIONS,
|
||||
)
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -268,7 +267,9 @@ class FileService:
|
||||
session.delete(upload_file)
|
||||
|
||||
@staticmethod
|
||||
def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
|
||||
def get_upload_files_by_ids(
|
||||
session: Session, tenant_id: str, upload_file_ids: Sequence[str]
|
||||
) -> dict[str, UploadFile]:
|
||||
"""
|
||||
Fetch `UploadFile` rows for a tenant in a single batch query.
|
||||
|
||||
@ -282,7 +283,7 @@ class FileService:
|
||||
unique_upload_file_ids: list[str] = list(set(upload_file_id_list))
|
||||
|
||||
# Fetch upload files in one query for efficient batch access.
|
||||
upload_files: Sequence[UploadFile] = db.session.scalars(
|
||||
upload_files: Sequence[UploadFile] = session.scalars(
|
||||
select(UploadFile).where(
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
UploadFile.id.in_(unique_upload_file_ids),
|
||||
|
||||
@ -69,7 +69,7 @@ def test_build_upload_files_zip_tempfile_sanitizes_and_dedupes_names(monkeypatch
|
||||
|
||||
def test_get_upload_files_by_ids_returns_empty_when_no_ids(db_session_with_containers: Session) -> None:
|
||||
"""Ensure empty input returns an empty mapping without hitting the database."""
|
||||
assert FileService.get_upload_files_by_ids(str(uuid4()), []) == {}
|
||||
assert FileService.get_upload_files_by_ids(db_session_with_containers, str(uuid4()), []) == {}
|
||||
|
||||
|
||||
def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_containers: Session) -> None:
|
||||
@ -78,7 +78,7 @@ def test_get_upload_files_by_ids_returns_id_keyed_mapping(db_session_with_contai
|
||||
file1 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k1", name="file1.txt")
|
||||
file2 = _create_upload_file(db_session_with_containers, tenant_id=tenant_id, key="k2", name="file2.txt")
|
||||
|
||||
result = FileService.get_upload_files_by_ids(tenant_id, [file1.id, file1.id, file2.id])
|
||||
result = FileService.get_upload_files_by_ids(db_session_with_containers, tenant_id, [file1.id, file1.id, file2.id])
|
||||
|
||||
assert set(result.keys()) == {file1.id, file2.id}
|
||||
assert result[file1.id].id == file1.id
|
||||
@ -92,6 +92,6 @@ def test_get_upload_files_by_ids_filters_by_tenant(db_session_with_containers: S
|
||||
file_a = _create_upload_file(db_session_with_containers, tenant_id=tenant_a, key="ka", name="a.txt")
|
||||
_create_upload_file(db_session_with_containers, tenant_id=tenant_b, key="kb", name="b.txt")
|
||||
|
||||
result = FileService.get_upload_files_by_ids(tenant_a, [file_a.id])
|
||||
result = FileService.get_upload_files_by_ids(db_session_with_containers, tenant_a, [file_a.id])
|
||||
|
||||
assert set(result.keys()) == {file_a.id}
|
||||
|
||||
@ -375,19 +375,19 @@ class TestFileService:
|
||||
file_service.delete_file("file_id")
|
||||
# Should return without doing anything
|
||||
|
||||
@patch("services.file_service.db")
|
||||
def test_get_upload_files_by_ids_empty(self, mock_db):
|
||||
result = FileService.get_upload_files_by_ids("tenant_id", [])
|
||||
def test_get_upload_files_by_ids_empty(self):
|
||||
session = MagicMock()
|
||||
result = FileService.get_upload_files_by_ids(session, "tenant_id", [])
|
||||
assert result == {}
|
||||
|
||||
@patch("services.file_service.db")
|
||||
def test_get_upload_files_by_ids(self, mock_db):
|
||||
def test_get_upload_files_by_ids(self):
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.id = "550e8400-e29b-41d4-a716-446655440000"
|
||||
upload_file.tenant_id = "tenant_id"
|
||||
mock_db.session.scalars().all.return_value = [upload_file]
|
||||
session = MagicMock()
|
||||
session.scalars().all.return_value = [upload_file]
|
||||
|
||||
result = FileService.get_upload_files_by_ids("tenant_id", ["550e8400-e29b-41d4-a716-446655440000"])
|
||||
result = FileService.get_upload_files_by_ids(session, "tenant_id", ["550e8400-e29b-41d4-a716-446655440000"])
|
||||
assert result["550e8400-e29b-41d4-a716-446655440000"] == upload_file
|
||||
|
||||
def test_sanitize_zip_entry_name(self):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user