fix(dataset): include segment created_at in hit testing response (#37181)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yunlu Wen 2026-06-09 13:15:36 +08:00 committed by GitHub
parent 5bec8eb33a
commit eb3b12fa70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 24 deletions

View File

@ -7,6 +7,7 @@ from sqlalchemy import select
from core.app.app_config.entities import ModelConfig
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -55,17 +56,16 @@ class HitTestingService:
}
@classmethod
def _dump_retrieval_records(cls, records: list[Any]) -> list[dict[str, Any]]:
dumped_records = [record.model_dump() for record in records]
def _dump_retrieval_records(cls, records: list[RetrievalSegments]) -> list[dict[str, Any]]:
document_ids = {
segment.get("document_id")
for record in dumped_records
if isinstance(record, dict)
for segment in [record.get("segment")]
if isinstance(segment, dict) and segment.get("document_id")
document_id
for record in records
if record.segment
for document_id in [record.segment.document_id]
if isinstance(document_id, str) and document_id
}
if not document_ids:
return dumped_records
return [record.model_dump() for record in records]
documents = {
document.id: cls._dump_dataset_document(document)
@ -76,18 +76,23 @@ class HitTestingService:
records_with_documents: list[dict[str, Any]] = []
missing_document_ids: set[str] = set()
for record in dumped_records:
segment = record.get("segment")
if not isinstance(segment, dict):
records_with_documents.append(record)
for retrieval_record in records:
segment = retrieval_record.segment
if not segment or not isinstance(segment.document_id, str) or not segment.document_id:
records_with_documents.append(retrieval_record.model_dump())
continue
document_id = segment.get("document_id")
if document_id in documents:
segment["document"] = documents[document_id]
records_with_documents.append(record)
elif document_id:
document_id = segment.document_id
document = documents.get(document_id)
if document is None:
missing_document_ids.add(document_id)
continue
record = retrieval_record.model_dump()
segment_dict = record["segment"]
segment_dict["created_at"] = segment.created_at
segment_dict["document"] = document
records_with_documents.append(record)
if missing_document_ids:
logger.warning(

View File

@ -1,3 +1,4 @@
from datetime import datetime
from unittest.mock import Mock, patch
from services.hit_testing_service import HitTestingService
@ -6,6 +7,14 @@ from services.hit_testing_service import HitTestingService
def _retrieval_record(payload: dict):
record = Mock()
record.model_dump.return_value = payload
segment = payload.get("segment")
if isinstance(segment, dict):
record.segment = Mock()
record.segment.id = segment.get("id")
record.segment.document_id = segment.get("document_id")
record.segment.created_at = datetime(2024, 1, 1, 0, 0, 0)
else:
record.segment = None
return record
@ -38,12 +47,14 @@ class TestHitTestingServiceDumpRecords:
}
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self):
record = _retrieval_record({"segment": None, "score": 0.95})
record = _retrieval_record({"segment": {"id": "segment-1", "document_id": None}, "score": 0.95})
record.segment.document_id = None
assert HitTestingService._dump_retrieval_records([record]) == [{"segment": None, "score": 0.95}]
assert HitTestingService._dump_retrieval_records([record]) == [
{"segment": {"id": "segment-1", "document_id": None}, "score": 0.95}
]
def test_dump_retrieval_records_injects_documents_and_keeps_non_segment_records(self):
record_without_segment = _retrieval_record({"segment": None, "score": 0.95})
def test_dump_retrieval_records_injects_documents(self):
record_with_document = _retrieval_record(
{
"segment": {
@ -57,10 +68,9 @@ class TestHitTestingServiceDumpRecords:
scalars_result.all.return_value = [_dataset_document()]
with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result):
result = HitTestingService._dump_retrieval_records([record_without_segment, record_with_document])
result = HitTestingService._dump_retrieval_records([record_with_document])
assert result[0] == {"segment": None, "score": 0.95}
assert result[1]["segment"]["document"] == {
assert result[0]["segment"]["document"] == {
"id": "document-1",
"data_source_type": "upload_file",
"name": "guide.md",
@ -68,6 +78,8 @@ class TestHitTestingServiceDumpRecords:
"doc_metadata": None,
}
assert result[0]["segment"]["created_at"] == datetime(2024, 1, 1, 0, 0, 0)
def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog):
record = _retrieval_record(
{