mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 18:24:09 +08:00
fix(dataset): include segment created_at in hit testing response (#37181)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5bec8eb33a
commit
eb3b12fa70
@ -7,6 +7,7 @@ from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
from core.rag.embedding.retrieval import RetrievalSegments
|
||||
from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
@ -55,17 +56,16 @@ class HitTestingService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _dump_retrieval_records(cls, records: list[Any]) -> list[dict[str, Any]]:
|
||||
dumped_records = [record.model_dump() for record in records]
|
||||
def _dump_retrieval_records(cls, records: list[RetrievalSegments]) -> list[dict[str, Any]]:
|
||||
document_ids = {
|
||||
segment.get("document_id")
|
||||
for record in dumped_records
|
||||
if isinstance(record, dict)
|
||||
for segment in [record.get("segment")]
|
||||
if isinstance(segment, dict) and segment.get("document_id")
|
||||
document_id
|
||||
for record in records
|
||||
if record.segment
|
||||
for document_id in [record.segment.document_id]
|
||||
if isinstance(document_id, str) and document_id
|
||||
}
|
||||
if not document_ids:
|
||||
return dumped_records
|
||||
return [record.model_dump() for record in records]
|
||||
|
||||
documents = {
|
||||
document.id: cls._dump_dataset_document(document)
|
||||
@ -76,18 +76,23 @@ class HitTestingService:
|
||||
|
||||
records_with_documents: list[dict[str, Any]] = []
|
||||
missing_document_ids: set[str] = set()
|
||||
for record in dumped_records:
|
||||
segment = record.get("segment")
|
||||
if not isinstance(segment, dict):
|
||||
records_with_documents.append(record)
|
||||
for retrieval_record in records:
|
||||
segment = retrieval_record.segment
|
||||
if not segment or not isinstance(segment.document_id, str) or not segment.document_id:
|
||||
records_with_documents.append(retrieval_record.model_dump())
|
||||
continue
|
||||
|
||||
document_id = segment.get("document_id")
|
||||
if document_id in documents:
|
||||
segment["document"] = documents[document_id]
|
||||
records_with_documents.append(record)
|
||||
elif document_id:
|
||||
document_id = segment.document_id
|
||||
document = documents.get(document_id)
|
||||
if document is None:
|
||||
missing_document_ids.add(document_id)
|
||||
continue
|
||||
|
||||
record = retrieval_record.model_dump()
|
||||
segment_dict = record["segment"]
|
||||
segment_dict["created_at"] = segment.created_at
|
||||
segment_dict["document"] = document
|
||||
records_with_documents.append(record)
|
||||
|
||||
if missing_document_ids:
|
||||
logger.warning(
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -6,6 +7,14 @@ from services.hit_testing_service import HitTestingService
|
||||
def _retrieval_record(payload: dict):
|
||||
record = Mock()
|
||||
record.model_dump.return_value = payload
|
||||
segment = payload.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
record.segment = Mock()
|
||||
record.segment.id = segment.get("id")
|
||||
record.segment.document_id = segment.get("document_id")
|
||||
record.segment.created_at = datetime(2024, 1, 1, 0, 0, 0)
|
||||
else:
|
||||
record.segment = None
|
||||
return record
|
||||
|
||||
|
||||
@ -38,12 +47,14 @@ class TestHitTestingServiceDumpRecords:
|
||||
}
|
||||
|
||||
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self):
|
||||
record = _retrieval_record({"segment": None, "score": 0.95})
|
||||
record = _retrieval_record({"segment": {"id": "segment-1", "document_id": None}, "score": 0.95})
|
||||
record.segment.document_id = None
|
||||
|
||||
assert HitTestingService._dump_retrieval_records([record]) == [{"segment": None, "score": 0.95}]
|
||||
assert HitTestingService._dump_retrieval_records([record]) == [
|
||||
{"segment": {"id": "segment-1", "document_id": None}, "score": 0.95}
|
||||
]
|
||||
|
||||
def test_dump_retrieval_records_injects_documents_and_keeps_non_segment_records(self):
|
||||
record_without_segment = _retrieval_record({"segment": None, "score": 0.95})
|
||||
def test_dump_retrieval_records_injects_documents(self):
|
||||
record_with_document = _retrieval_record(
|
||||
{
|
||||
"segment": {
|
||||
@ -57,10 +68,9 @@ class TestHitTestingServiceDumpRecords:
|
||||
scalars_result.all.return_value = [_dataset_document()]
|
||||
|
||||
with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result):
|
||||
result = HitTestingService._dump_retrieval_records([record_without_segment, record_with_document])
|
||||
result = HitTestingService._dump_retrieval_records([record_with_document])
|
||||
|
||||
assert result[0] == {"segment": None, "score": 0.95}
|
||||
assert result[1]["segment"]["document"] == {
|
||||
assert result[0]["segment"]["document"] == {
|
||||
"id": "document-1",
|
||||
"data_source_type": "upload_file",
|
||||
"name": "guide.md",
|
||||
@ -68,6 +78,8 @@ class TestHitTestingServiceDumpRecords:
|
||||
"doc_metadata": None,
|
||||
}
|
||||
|
||||
assert result[0]["segment"]["created_at"] == datetime(2024, 1, 1, 0, 0, 0)
|
||||
|
||||
def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog):
|
||||
record = _retrieval_record(
|
||||
{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user