mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: migrate session.query to select API in retrieval_service (#34638)
This commit is contained in:
parent
1194957fde
commit
72adb5468c
@ -240,7 +240,7 @@ class RetrievalService:
|
||||
@classmethod
|
||||
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
|
||||
with Session(db.engine) as session:
|
||||
return session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
@ -573,15 +573,13 @@ class RetrievalService:
|
||||
|
||||
# Batch query summaries for segments retrieved via summary (only enabled summaries)
|
||||
if summary_segment_ids:
|
||||
summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
summaries = session.scalars(
|
||||
select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
|
||||
DocumentSegmentSummary.status == "completed",
|
||||
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
for summary in summaries:
|
||||
if summary.summary_content:
|
||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
@ -851,12 +849,12 @@ class RetrievalService:
|
||||
def get_segment_attachment_info(
|
||||
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
|
||||
) -> SegmentAttachmentResult | None:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1))
|
||||
if upload_file:
|
||||
attachment_binding = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
attachment_binding = session.scalar(
|
||||
select(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if attachment_binding:
|
||||
attachment_info: AttachmentInfoDict = {
|
||||
@ -875,14 +873,12 @@ class RetrievalService:
|
||||
cls, attachment_ids: list[str], session: Session
|
||||
) -> list[SegmentAttachmentInfoResult]:
|
||||
attachment_infos: list[SegmentAttachmentInfoResult] = []
|
||||
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
|
||||
if upload_files:
|
||||
upload_file_ids = [upload_file.id for upload_file in upload_files]
|
||||
attachment_bindings = (
|
||||
session.query(SegmentAttachmentBinding)
|
||||
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
||||
.all()
|
||||
)
|
||||
attachment_bindings = session.scalars(
|
||||
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
|
||||
).all()
|
||||
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
|
||||
|
||||
if attachment_bindings:
|
||||
|
||||
@ -119,6 +119,14 @@ class _FakeSummaryQuery:
|
||||
return self._summaries
|
||||
|
||||
|
||||
class _FakeScalarsResult:
|
||||
def __init__(self, data: list) -> None:
|
||||
self._data = data
|
||||
|
||||
def all(self) -> list:
|
||||
return self._data
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, execute_payloads: list[list], summaries: list) -> None:
|
||||
self._payloads = list(execute_payloads)
|
||||
@ -128,8 +136,8 @@ class _FakeSession:
|
||||
data = self._payloads.pop(0) if self._payloads else []
|
||||
return _FakeExecuteResult(data)
|
||||
|
||||
def query(self, model):
|
||||
return _FakeSummaryQuery(self._summaries)
|
||||
def scalars(self, stmt):
|
||||
return _FakeScalarsResult(self._summaries)
|
||||
|
||||
|
||||
class _FakeSessionContext:
|
||||
@ -265,14 +273,14 @@ class TestRetrievalServiceInternals:
|
||||
def test_get_dataset_queries_by_id(self, mock_session_class):
|
||||
expected_dataset = Mock(spec=Dataset)
|
||||
mock_session = Mock()
|
||||
mock_session.query.return_value.where.return_value.first.return_value = expected_dataset
|
||||
mock_session.scalar.return_value = expected_dataset
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch.object(retrieval_service_module, "db", SimpleNamespace(engine=Mock())):
|
||||
result = RetrievalService._get_dataset("dataset-123")
|
||||
|
||||
assert result == expected_dataset
|
||||
mock_session.query.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.Keyword")
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
|
||||
@ -1046,12 +1054,8 @@ class TestRetrievalServiceInternals:
|
||||
size=42,
|
||||
)
|
||||
binding = SimpleNamespace(segment_id="segment-1", attachment_id="upload-1")
|
||||
upload_query = Mock()
|
||||
upload_query.where.return_value.first.return_value = upload_file
|
||||
binding_query = Mock()
|
||||
binding_query.where.return_value.first.return_value = binding
|
||||
session = Mock()
|
||||
session.query.side_effect = [upload_query, binding_query]
|
||||
session.scalar.side_effect = [upload_file, binding]
|
||||
|
||||
result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session)
|
||||
|
||||
@ -1076,32 +1080,26 @@ class TestRetrievalServiceInternals:
|
||||
mime_type="image/png",
|
||||
size=42,
|
||||
)
|
||||
upload_query = Mock()
|
||||
upload_query.where.return_value.first.return_value = upload_file
|
||||
binding_query = Mock()
|
||||
binding_query.where.return_value.first.return_value = None
|
||||
session = Mock()
|
||||
session.query.side_effect = [upload_query, binding_query]
|
||||
session.scalar.side_effect = [upload_file, None]
|
||||
|
||||
result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_segment_attachment_info_returns_none_when_upload_file_missing(self):
|
||||
upload_query = Mock()
|
||||
upload_query.where.return_value.first.return_value = None
|
||||
session = Mock()
|
||||
session.query.return_value = upload_query
|
||||
session.scalar.return_value = None
|
||||
|
||||
result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_segment_attachment_infos_returns_empty_when_upload_files_missing(self):
|
||||
upload_query = Mock()
|
||||
upload_query.where.return_value.all.return_value = []
|
||||
scalars_result = Mock()
|
||||
scalars_result.all.return_value = []
|
||||
session = Mock()
|
||||
session.query.return_value = upload_query
|
||||
session.scalars.return_value = scalars_result
|
||||
|
||||
result = RetrievalService.get_segment_attachment_infos(["upload-1"], session)
|
||||
|
||||
@ -1115,12 +1113,12 @@ class TestRetrievalServiceInternals:
|
||||
mime_type="image/png",
|
||||
size=42,
|
||||
)
|
||||
upload_query = Mock()
|
||||
upload_query.where.return_value.all.return_value = [upload_file]
|
||||
binding_query = Mock()
|
||||
binding_query.where.return_value.all.return_value = []
|
||||
upload_scalars = Mock()
|
||||
upload_scalars.all.return_value = [upload_file]
|
||||
binding_scalars = Mock()
|
||||
binding_scalars.all.return_value = []
|
||||
session = Mock()
|
||||
session.query.side_effect = [upload_query, binding_query]
|
||||
session.scalars.side_effect = [upload_scalars, binding_scalars]
|
||||
|
||||
result = RetrievalService.get_segment_attachment_infos(["upload-1"], session)
|
||||
|
||||
@ -1144,12 +1142,12 @@ class TestRetrievalServiceInternals:
|
||||
)
|
||||
binding = SimpleNamespace(attachment_id="upload-1", segment_id="segment-1")
|
||||
|
||||
upload_query = Mock()
|
||||
upload_query.where.return_value.all.return_value = [upload_file_1, upload_file_2]
|
||||
binding_query = Mock()
|
||||
binding_query.where.return_value.all.return_value = [binding]
|
||||
upload_scalars = Mock()
|
||||
upload_scalars.all.return_value = [upload_file_1, upload_file_2]
|
||||
binding_scalars = Mock()
|
||||
binding_scalars.all.return_value = [binding]
|
||||
session = Mock()
|
||||
session.query.side_effect = [upload_query, binding_query]
|
||||
session.scalars.side_effect = [upload_scalars, binding_scalars]
|
||||
|
||||
result = RetrievalService.get_segment_attachment_infos(["upload-1", "upload-2"], session)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user