# Patch summary (commentary only; everything from the "diff --git" header on is unchanged).
#
# api/services/hit_testing_service.py
#   * Imports RetrievalSegments from core.rag.embedding.retrieval.
#   * HitTestingService._dump_retrieval_records now annotates `records` as
#     list[RetrievalSegments] and reads `record.segment` / `segment.document_id`
#     from the record objects instead of inspecting pre-dumped dicts.
#   * Document ids are collected only when `record.segment` is truthy and
#     `document_id` is a non-empty str; when none are collected the method
#     returns `[record.model_dump() for record in records]`.
#   * In the main loop, a record with no segment or no usable document_id is
#     appended as `model_dump()` unchanged; a record whose document_id is not in
#     `documents` is added to `missing_document_ids` and not appended; otherwise
#     the dumped record's "segment" dict gets "created_at" (from
#     `segment.created_at`) and "document" before being appended.
#
# api/tests/unit_tests/services/test_hit_testing_service_dump_records.py
#   * `_retrieval_record` attaches a Mock `segment` (id, document_id,
#     created_at=datetime(2024, 1, 1, 0, 0, 0)) when the payload has a dict
#     "segment", else sets `record.segment = None`.
#   * The "without document ids" test now uses a segment whose document_id is
#     None; the inject test drops the segment-less record and asserts that
#     "created_at" is injected.
#
# NOTE(review): "created_at" is overwritten only on the document-attached path;
#   the early return and the pass-through branch emit `model_dump()` as-is, so
#   the field may be represented differently between paths -- confirm intended.
# NOTE(review): none of the tests visible in these hunks still pass a record
#   with `segment = None`, although the `not segment` branch remains -- check
#   whether another test outside this view covers it.
# NOTE(review): `record.segment.document_id = None` in the test appears
#   redundant, since the helper already assigns `segment.get("document_id")`.
# NOTE(review): the patch text below has lost its line breaks; a unified diff
#   needs one record per line, so it presumably must be re-wrapped before it
#   can be applied -- verify against the original commit.
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 42c531ae48..6c70e324b1 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -7,6 +7,7 @@ from sqlalchemy import select from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService +from core.rag.embedding.retrieval import RetrievalSegments from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -55,17 +56,16 @@ class HitTestingService: } @classmethod - def _dump_retrieval_records(cls, records: list[Any]) -> list[dict[str, Any]]: - dumped_records = [record.model_dump() for record in records] + def _dump_retrieval_records(cls, records: list[RetrievalSegments]) -> list[dict[str, Any]]: document_ids = { - segment.get("document_id") - for record in dumped_records - if isinstance(record, dict) - for segment in [record.get("segment")] - if isinstance(segment, dict) and segment.get("document_id") + document_id + for record in records + if record.segment + for document_id in [record.segment.document_id] + if isinstance(document_id, str) and document_id } if not document_ids: - return dumped_records + return [record.model_dump() for record in records] documents = { document.id: cls._dump_dataset_document(document) @@ -76,18 +76,23 @@ class HitTestingService: records_with_documents: list[dict[str, Any]] = [] missing_document_ids: set[str] = set() - for record in dumped_records: - segment = record.get("segment") - if not isinstance(segment, dict): - records_with_documents.append(record) + for retrieval_record in records: + segment = retrieval_record.segment + if not segment or not isinstance(segment.document_id, str) or not segment.document_id: + records_with_documents.append(retrieval_record.model_dump()) continue - document_id = 
segment.get("document_id") - if document_id in documents: - segment["document"] = documents[document_id] - records_with_documents.append(record) - elif document_id: + document_id = segment.document_id + document = documents.get(document_id) + if document is None: missing_document_ids.add(document_id) + continue + + record = retrieval_record.model_dump() + segment_dict = record["segment"] + segment_dict["created_at"] = segment.created_at + segment_dict["document"] = document + records_with_documents.append(record) if missing_document_ids: logger.warning( diff --git a/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py b/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py index f933c5e440..fa0a655204 100644 --- a/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py +++ b/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py @@ -1,3 +1,4 @@ +from datetime import datetime from unittest.mock import Mock, patch from services.hit_testing_service import HitTestingService @@ -6,6 +7,14 @@ from services.hit_testing_service import HitTestingService def _retrieval_record(payload: dict): record = Mock() record.model_dump.return_value = payload + segment = payload.get("segment") + if isinstance(segment, dict): + record.segment = Mock() + record.segment.id = segment.get("id") + record.segment.document_id = segment.get("document_id") + record.segment.created_at = datetime(2024, 1, 1, 0, 0, 0) + else: + record.segment = None return record @@ -38,12 +47,14 @@ class TestHitTestingServiceDumpRecords: } def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self): - record = _retrieval_record({"segment": None, "score": 0.95}) + record = _retrieval_record({"segment": {"id": "segment-1", "document_id": None}, "score": 0.95}) + record.segment.document_id = None - assert HitTestingService._dump_retrieval_records([record]) == [{"segment": None, "score": 0.95}] + assert 
HitTestingService._dump_retrieval_records([record]) == [ + {"segment": {"id": "segment-1", "document_id": None}, "score": 0.95} + ] - def test_dump_retrieval_records_injects_documents_and_keeps_non_segment_records(self): - record_without_segment = _retrieval_record({"segment": None, "score": 0.95}) + def test_dump_retrieval_records_injects_documents(self): record_with_document = _retrieval_record( { "segment": { @@ -57,10 +68,9 @@ class TestHitTestingServiceDumpRecords: scalars_result.all.return_value = [_dataset_document()] with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result): - result = HitTestingService._dump_retrieval_records([record_without_segment, record_with_document]) + result = HitTestingService._dump_retrieval_records([record_with_document]) - assert result[0] == {"segment": None, "score": 0.95} - assert result[1]["segment"]["document"] == { + assert result[0]["segment"]["document"] == { "id": "document-1", "data_source_type": "upload_file", "name": "guide.md", @@ -68,6 +78,8 @@ class TestHitTestingServiceDumpRecords: "doc_metadata": None, } + assert result[0]["segment"]["created_at"] == datetime(2024, 1, 1, 0, 0, 0) + def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog): record = _retrieval_record( {