mirror of
https://github.com/langgenius/dify.git
synced 2026-06-22 19:21:13 +08:00
test: migrate hit testing dump record tests (#37672)
This commit is contained in:
parent
a7b53b33ee
commit
517b27c2b4
@ -9,7 +9,6 @@ extend-select = ["ANN401", "ARG", "TID251"]
|
||||
"models/test_types_enum_text.py" = ["ANN401", "TID251"]
|
||||
"services/test_app_dsl_service.py" = ["ANN401", "TID251", "ARG"]
|
||||
"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251", "ARG"]
|
||||
"services/test_hit_testing_service.py" = ["ANN401", "TID251"]
|
||||
"trigger/conftest.py" = ["ANN401", "TID251"]
|
||||
"trigger/test_trigger_e2e.py" = ["ANN401", "TID251", "ARG"]
|
||||
"controllers/console/app/test_app_apis.py" = ["ARG"]
|
||||
|
||||
@ -100,7 +100,6 @@ project-excludes = [
|
||||
"services/test_feature_service.py",
|
||||
"services/test_feedback_service.py",
|
||||
"services/test_file_service.py",
|
||||
"services/test_hit_testing_service.py",
|
||||
"services/test_human_input_delivery_test.py",
|
||||
"services/test_human_input_delivery_test_service.py",
|
||||
"services/test_message_export_service.py",
|
||||
|
||||
@ -1,26 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.embedding.retrieval import RetrievalSegments
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, SegmentStatus
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
|
||||
tenant_id = str(uuid4())
|
||||
created_by = str(uuid4())
|
||||
class _QueryResponse(BaseModel):
    """Typed view of the ``query`` entry of a hit-testing response payload."""

    content: str
|
||||
|
||||
|
||||
class _RetrieveRecordResponse(BaseModel):
    """One entry of ``records`` in a hit-testing response; both declared fields are optional."""

    # Keys that are not declared below are accepted instead of being rejected.
    model_config = ConfigDict(extra="allow")

    content: str | None = None
    title: str | None = None
|
||||
|
||||
|
||||
class _RetrieveResponse(BaseModel):
    """Top-level response shape the tests validate service results against: a query plus records."""

    query: _QueryResponse
    records: list[_RetrieveRecordResponse]
|
||||
|
||||
|
||||
class _DumpedDocumentResponse(BaseModel):
    """Document summary expected inside a dumped segment.

    ``doc_type`` and ``doc_metadata`` may be ``None`` but have no default, so both keys must be present.
    """

    id: str
    data_source_type: str
    name: str
    doc_type: str | None
    doc_metadata: dict[str, object] | None
|
||||
|
||||
|
||||
class _DumpedSegmentResponse(BaseModel):
    """Segment portion of a dumped retrieval record; only the asserted fields are declared."""

    # Keys that are not declared below are accepted instead of being rejected.
    model_config = ConfigDict(extra="allow")

    id: str
    document_id: str
    created_at: datetime | None = None
    document: _DumpedDocumentResponse | None = None
|
||||
|
||||
|
||||
class _DumpedRetrievalRecordResponse(BaseModel):
    """A dumped retrieval record: the segment payload together with its score."""

    # Keys that are not declared below are accepted instead of being rejected.
    model_config = ConfigDict(extra="allow")

    segment: _DumpedSegmentResponse
    score: float
|
||||
|
||||
|
||||
# Adapter used by the tests to validate a whole list of dumped retrieval records in one call.
_DUMPED_RETRIEVAL_RECORDS = TypeAdapter(list[_DumpedRetrievalRecordResponse])
|
||||
|
||||
|
||||
def _create_dataset(
|
||||
db_session: Session,
|
||||
*,
|
||||
provider: str = "vendor",
|
||||
tenant_id: str | None = None,
|
||||
created_by: str | None = None,
|
||||
name: str = "test-dataset",
|
||||
) -> Dataset:
|
||||
ds = Dataset(
|
||||
tenant_id=kwargs.get("tenant_id", tenant_id),
|
||||
name=kwargs.get("name", "test-dataset"),
|
||||
created_by=kwargs.get("created_by", created_by),
|
||||
tenant_id=tenant_id or str(uuid4()),
|
||||
name=name,
|
||||
created_by=created_by or str(uuid4()),
|
||||
provider=provider,
|
||||
)
|
||||
db_session.add(ds)
|
||||
@ -29,36 +82,106 @@ def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs:
|
||||
return ds
|
||||
|
||||
|
||||
def _create_dataset_document(
    db_session: Session,
    *,
    name: str = "guide.md",
    data_source_type: str = DataSourceType.UPLOAD_FILE,
    doc_type: str | None = None,
    doc_metadata: dict[str, object] | None = None,
) -> DatasetDocument:
    """Persist a fresh dataset plus one document in it and return the refreshed document.

    The dataset and the document share a newly generated tenant id and creator id.
    The dataset is only flushed (so its id can be used), then everything is committed
    together with the document.
    """
    owner = {"tenant_id": str(uuid4()), "created_by": str(uuid4())}

    parent_dataset = Dataset(
        name=f"dataset-{uuid4()}",
        data_source_type=DataSourceType.UPLOAD_FILE,
        **owner,
    )
    db_session.add(parent_dataset)
    # Flush so that parent_dataset.id is available for the document below.
    db_session.flush()

    dataset_document = DatasetDocument(
        dataset_id=parent_dataset.id,
        position=1,
        data_source_type=data_source_type,
        batch=f"batch-{uuid4()}",
        name=name,
        created_from=DocumentCreatedFrom.WEB,
        doc_type=doc_type,
        doc_metadata=doc_metadata,
        **owner,
    )
    db_session.add(dataset_document)
    db_session.commit()
    db_session.refresh(dataset_document)
    return dataset_document
|
||||
|
||||
|
||||
def _build_segment(
    *,
    document_id: str,
    tenant_id: str | None = None,
    dataset_id: str | None = None,
    created_by: str | None = None,
) -> DocumentSegment:
    """Build (without persisting) a completed segment for ``document_id``.

    Any of ``tenant_id``, ``dataset_id`` or ``created_by`` that is falsy is replaced
    by a newly generated UUID string.
    """
    resolved_tenant = tenant_id or str(uuid4())
    resolved_dataset = dataset_id or str(uuid4())
    resolved_creator = created_by or str(uuid4())

    return DocumentSegment(
        tenant_id=resolved_tenant,
        dataset_id=resolved_dataset,
        document_id=document_id,
        created_by=resolved_creator,
        position=1,
        content="segment content",
        word_count=2,
        tokens=2,
        status=SegmentStatus.COMPLETED,
    )
|
||||
|
||||
|
||||
def _create_segment(db_session: Session, *, document: DatasetDocument | None = None) -> DocumentSegment:
    """Persist a segment and return it refreshed from the session.

    With a ``document``, the segment copies its tenant, dataset, id and creator.
    Without one, the segment points at a random document id and the remaining
    identifiers are left for ``_build_segment`` to generate.
    """
    if document:
        new_segment = _build_segment(
            tenant_id=document.tenant_id,
            dataset_id=document.dataset_id,
            document_id=document.id,
            created_by=document.created_by,
        )
    else:
        new_segment = _build_segment(
            tenant_id=None,
            dataset_id=None,
            document_id=str(uuid4()),
            created_by=None,
        )

    db_session.add(new_segment)
    db_session.commit()
    db_session.refresh(new_segment)
    return new_segment
|
||||
|
||||
|
||||
class TestHitTestingService:
|
||||
# ── Utility methods (pure logic, no DB) ────────────────────────────
|
||||
|
||||
def test_escape_query_for_search_should_escape_double_quotes(self) -> None:
    """Double quotes in the query come back prefixed with a backslash."""
    query = 'test "query" with quotes'

    result = HitTestingService.escape_query_for_search(query)

    assert result == 'test \\"query\\" with quotes'
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_query(self) -> None:
    """A payload holding only a short query passes the check without raising."""
    HitTestingService.hit_testing_args_check({"query": "valid query"})
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_attachments(self) -> None:
    """A payload holding only a list of attachment ids passes the check without raising."""
    HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self) -> None:
    """An empty payload is rejected with a ``ValueError``."""
    with pytest.raises(ValueError, match="Query or attachment_ids is required"):
        HitTestingService.hit_testing_args_check({})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self) -> None:
    """A 251-character query is rejected with a ``ValueError``."""
    with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
        HitTestingService.hit_testing_args_check({"query": "a" * 251})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self) -> None:
    """A non-list ``attachment_ids`` value is rejected with a ``ValueError``."""
    with pytest.raises(ValueError, match="Attachment_ids must be a list"):
        HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
|
||||
|
||||
# ── Response formatting ────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
|
||||
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
|
||||
def test_compact_retrieve_response_should_format_correctly(self, mock_format: MagicMock) -> None:
|
||||
query = "test query"
|
||||
mock_doc = MagicMock(spec=Document)
|
||||
|
||||
@ -66,50 +189,49 @@ class TestHitTestingService:
|
||||
mock_record.model_dump.return_value = {"content": "formatted content"}
|
||||
mock_format.return_value = [mock_record]
|
||||
|
||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
|
||||
response = _RetrieveResponse.model_validate(HitTestingService.compact_retrieve_response(query, [mock_doc]))
|
||||
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert len(result["records"]) == 1
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
|
||||
assert response.query.content == query
|
||||
assert len(response.records) == 1
|
||||
assert response.records[0].content == "formatted content"
|
||||
mock_format.assert_called_once_with([mock_doc])
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
    self, db_session_with_containers: Session
) -> None:
    """For an ``external`` dataset the given documents are returned as records, in order."""
    dataset = _create_dataset(db_session_with_containers, provider="external")
    documents = [
        {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
        {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
    ]

    # Validate through the typed model so field access below needs no casts.
    response = _RetrieveResponse.model_validate(
        HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
    )

    assert response.query.content == "test query"
    assert len(response.records) == 2
    assert response.records[0].content == "c1"
    assert response.records[1].title == "t2"
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
    self, db_session_with_containers: Session
) -> None:
    """For a ``vendor`` dataset the query is echoed back but no records are returned."""
    dataset = _create_dataset(db_session_with_containers, provider="vendor")

    # Validate through the typed model so field access below needs no casts.
    response = _RetrieveResponse.model_validate(
        HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}])
    )

    assert response.query.content == "test query"
    assert response.records == []
|
||||
|
||||
# ── External retrieve (real DB) ────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
|
||||
def test_external_retrieve_should_succeed_for_external_provider(
|
||||
self, mock_ext_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
self, mock_ext_retrieve: MagicMock, db_session_with_containers: Session
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||
account_id = str(uuid4())
|
||||
account = MagicMock()
|
||||
@ -118,19 +240,18 @@ class TestHitTestingService:
|
||||
|
||||
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
response = _RetrieveResponse.model_validate(
|
||||
HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query='test "query"',
|
||||
account=account,
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
|
||||
assert response.query.content == 'test "query"'
|
||||
assert response.records[0].content == "ext content"
|
||||
mock_ext_retrieve.assert_called_once_with(
|
||||
dataset_id=dataset.id,
|
||||
query='test \\"query\\"',
|
||||
@ -142,37 +263,44 @@ class TestHitTestingService:
|
||||
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
assert after_count == before_count + 1
|
||||
|
||||
def test_external_retrieve_should_return_empty_for_non_external_provider(
    self, db_session_with_containers: Session
) -> None:
    """``external_retrieve`` on a ``vendor`` dataset echoes the query and returns no records."""
    dataset = _create_dataset(db_session_with_containers, provider="vendor")
    account = MagicMock()

    # Validate through the typed model so field access below needs no casts.
    response = _RetrieveResponse.model_validate(HitTestingService.external_retrieve(dataset, "test query", account))

    assert response.query.content == "test query"
    assert response.records == []
|
||||
|
||||
# ── Retrieve (real DB) ─────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
def test_retrieve_should_use_default_model_when_none_provided(
|
||||
self, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
self, mock_retrieve: MagicMock, db_session_with_containers: Session
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
dataset.retrieval_model = None
|
||||
account = MagicMock()
|
||||
account.id = str(uuid4())
|
||||
mock_retrieve.return_value = []
|
||||
retrieved_documents: list[Document] = []
|
||||
mock_retrieve.return_value = retrieved_documents
|
||||
external_retrieval_model: dict[str, object] = {}
|
||||
|
||||
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
response = _RetrieveResponse.model_validate(
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
|
||||
),
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=None,
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
)
|
||||
)
|
||||
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert response.query.content == "test query"
|
||||
mock_retrieve.assert_called_once()
|
||||
assert mock_retrieve.call_args.kwargs["top_k"] == 4
|
||||
|
||||
@ -183,11 +311,12 @@ class TestHitTestingService:
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
def test_retrieve_should_handle_metadata_filtering(
|
||||
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
self, mock_get_meta: MagicMock, mock_retrieve: MagicMock, db_session_with_containers: Session
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = str(uuid4())
|
||||
external_retrieval_model: dict[str, object] = {}
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
@ -197,14 +326,15 @@ class TestHitTestingService:
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
|
||||
mock_retrieve.return_value = []
|
||||
retrieved_documents: list[Document] = []
|
||||
mock_retrieve.return_value = retrieved_documents
|
||||
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
)
|
||||
|
||||
mock_get_meta.assert_called_once()
|
||||
@ -214,10 +344,11 @@ class TestHitTestingService:
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
|
||||
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
self, mock_get_meta: MagicMock, mock_retrieve: MagicMock, db_session_with_containers: Session
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
external_retrieval_model: dict[str, object] = {}
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
@ -226,28 +357,31 @@ class TestHitTestingService:
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
mock_get_meta.return_value = ({}, "condition_string")
|
||||
empty_document_ids: dict[str, list[str]] = {}
|
||||
mock_get_meta.return_value = (empty_document_ids, "condition_string")
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
response = _RetrieveResponse.model_validate(
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
),
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
)
|
||||
)
|
||||
|
||||
assert result["records"] == []
|
||||
assert response.records == []
|
||||
mock_retrieve.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
|
||||
def test_retrieve_should_handle_attachments(
|
||||
self, mock_retrieve: MagicMock, db_session_with_containers: Session
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = str(uuid4())
|
||||
attachment_ids = ["att1", "att2"]
|
||||
external_retrieval_model: dict[str, object] = {}
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
@ -255,19 +389,20 @@ class TestHitTestingService:
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
retrieved_documents: list[Document] = []
|
||||
mock_retrieve.return_value = retrieved_documents
|
||||
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once_with(
|
||||
retrieval_method=ANY,
|
||||
retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
dataset_id=dataset.id,
|
||||
query="test query",
|
||||
attachment_ids=attachment_ids,
|
||||
@ -295,10 +430,13 @@ class TestHitTestingService:
|
||||
assert query_content[1]["content"] == "att1"
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
|
||||
def test_retrieve_should_handle_reranking_and_threshold(
|
||||
self, mock_retrieve: MagicMock, db_session_with_containers: Session
|
||||
) -> None:
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = str(uuid4())
|
||||
external_retrieval_model: dict[str, object] = {}
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "hybrid_search",
|
||||
@ -310,14 +448,15 @@ class TestHitTestingService:
|
||||
"score_threshold": 0.5,
|
||||
"weights": {"vector": 0.5, "keyword": 0.5},
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
retrieved_documents: list[Document] = []
|
||||
mock_retrieve.return_value = retrieved_documents
|
||||
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
)
|
||||
|
||||
mock_retrieve.assert_called_once()
|
||||
@ -326,3 +465,57 @@ class TestHitTestingService:
|
||||
assert kwargs["reranking_model"] == {"provider": "test"}
|
||||
assert kwargs["reranking_mode"] == "weighted_sum"
|
||||
assert kwargs["weights"] == {"vector": 0.5, "keyword": 0.5}
|
||||
|
||||
def test_dump_dataset_document_returns_frontend_required_fields(self, db_session_with_containers: Session) -> None:
    """Dumping a persisted document yields exactly the five expected keys."""
    document = _create_dataset_document(db_session_with_containers, doc_metadata={"source": "manual"})
    expected = {
        "id": document.id,
        "data_source_type": "upload_file",
        "name": "guide.md",
        "doc_type": None,
        "doc_metadata": {"source": "manual"},
    }

    dumped = HitTestingService._dump_dataset_document(document)

    assert dumped == expected
|
||||
|
||||
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self) -> None:
    """A record whose segment has an empty ``document_id`` is still dumped, unchanged."""
    segment = _build_segment(document_id="")
    retrieval_record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95})

    dumped = HitTestingService._dump_retrieval_records([retrieval_record])
    records = _DUMPED_RETRIEVAL_RECORDS.validate_python(dumped)

    assert len(records) == 1
    only_record = records[0]
    assert only_record.segment.id == segment.id
    assert only_record.segment.document_id == ""
    assert only_record.score == 0.95
|
||||
|
||||
def test_dump_retrieval_records_injects_documents(self, db_session_with_containers: Session) -> None:
    """A record whose document exists in the database is dumped with that document attached."""
    document = _create_dataset_document(db_session_with_containers)
    segment = _create_segment(db_session_with_containers, document=document)
    retrieval_record = RetrievalSegments.model_validate({"segment": segment, "score": 0.9})
    expected_document = _DumpedDocumentResponse(
        id=document.id,
        data_source_type="upload_file",
        name="guide.md",
        doc_type=None,
        doc_metadata=None,
    )

    dumped = HitTestingService._dump_retrieval_records([retrieval_record])
    records = _DUMPED_RETRIEVAL_RECORDS.validate_python(dumped)

    assert len(records) == 1
    dumped_segment = records[0].segment
    assert dumped_segment.id == segment.id
    assert dumped_segment.document_id == document.id
    assert dumped_segment.created_at == segment.created_at
    assert dumped_segment.document == expected_document
    assert records[0].score == 0.9
|
||||
|
||||
def test_dump_retrieval_records_skips_records_with_missing_documents(
    self, db_session_with_containers: Session, caplog: pytest.LogCaptureFixture
) -> None:
    """A record pointing at a document id that was never persisted is dropped and a message is logged."""
    # No document is passed, so the segment references a random, non-existent document id.
    orphan_segment = _create_segment(db_session_with_containers)
    retrieval_record = RetrievalSegments.model_validate({"segment": orphan_segment, "score": 0.95})

    assert HitTestingService._dump_retrieval_records([retrieval_record]) == []
    assert "Skipping hit-testing records with missing documents" in caplog.text
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
def _retrieval_record(payload: dict):
|
||||
record = Mock()
|
||||
record.model_dump.return_value = payload
|
||||
segment = payload.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
record.segment = Mock()
|
||||
record.segment.id = segment.get("id")
|
||||
record.segment.document_id = segment.get("document_id")
|
||||
record.segment.created_at = datetime(2024, 1, 1, 0, 0, 0)
|
||||
else:
|
||||
record.segment = None
|
||||
return record
|
||||
|
||||
|
||||
def _dataset_document(
|
||||
document_id: str = "document-1",
|
||||
name: str = "guide.md",
|
||||
data_source_type: str = "upload_file",
|
||||
doc_type: str | None = None,
|
||||
doc_metadata: dict | None = None,
|
||||
):
|
||||
document = Mock()
|
||||
document.id = document_id
|
||||
document.name = name
|
||||
document.data_source_type = data_source_type
|
||||
document.doc_type = doc_type
|
||||
document.doc_metadata = doc_metadata
|
||||
return document
|
||||
|
||||
|
||||
class TestHitTestingServiceDumpRecords:
    """Mock-based tests for ``HitTestingService._dump_dataset_document`` and ``_dump_retrieval_records``."""

    def test_dump_dataset_document_returns_frontend_required_fields(self):
        expected = {
            "id": "document-1",
            "data_source_type": "upload_file",
            "name": "guide.md",
            "doc_type": None,
            "doc_metadata": {"source": "manual"},
        }

        dumped = HitTestingService._dump_dataset_document(_dataset_document(doc_metadata={"source": "manual"}))

        assert dumped == expected

    def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self):
        record = _retrieval_record({"segment": {"id": "segment-1", "document_id": None}, "score": 0.95})
        record.segment.document_id = None

        dumped = HitTestingService._dump_retrieval_records([record])

        assert dumped == [{"segment": {"id": "segment-1", "document_id": None}, "score": 0.95}]

    def test_dump_retrieval_records_injects_documents(self):
        segment_payload = {"id": "segment-1", "document_id": "document-1"}
        record_with_document = _retrieval_record({"segment": segment_payload, "score": 0.9})
        # Stand-in for the object returned by db.session.scalars(...): .all() yields one document.
        scalars_result = Mock(**{"all.return_value": [_dataset_document()]})

        with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result):
            result = HitTestingService._dump_retrieval_records([record_with_document])

        dumped_segment = result[0]["segment"]
        assert dumped_segment["document"] == {
            "id": "document-1",
            "data_source_type": "upload_file",
            "name": "guide.md",
            "doc_type": None,
            "doc_metadata": None,
        }
        assert dumped_segment["created_at"] == datetime(2024, 1, 1, 0, 0, 0)

    def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog: pytest.LogCaptureFixture):
        segment_payload = {"id": "segment-1", "document_id": "missing-document"}
        record = _retrieval_record({"segment": segment_payload, "score": 0.95})
        # Stand-in for the object returned by db.session.scalars(...): .all() yields no documents.
        scalars_result = Mock(**{"all.return_value": []})

        with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result):
            result = HitTestingService._dump_retrieval_records([record])

        assert result == []
        assert "Skipping hit-testing records with missing documents" in caplog.text
|
||||
Loading…
Reference in New Issue
Block a user