test: migrate hit testing dump record tests (#37672)

This commit is contained in:
Escape0707 2026-06-20 21:12:14 +09:00 committed by GitHub
parent a7b53b33ee
commit 517b27c2b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 264 additions and 175 deletions

View File

@ -9,7 +9,6 @@ extend-select = ["ANN401", "ARG", "TID251"]
"models/test_types_enum_text.py" = ["ANN401", "TID251"]
"services/test_app_dsl_service.py" = ["ANN401", "TID251", "ARG"]
"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251", "ARG"]
"services/test_hit_testing_service.py" = ["ANN401", "TID251"]
"trigger/conftest.py" = ["ANN401", "TID251"]
"trigger/test_trigger_e2e.py" = ["ANN401", "TID251", "ARG"]
"controllers/console/app/test_app_apis.py" = ["ARG"]

View File

@ -100,7 +100,6 @@ project-excludes = [
"services/test_feature_service.py",
"services/test_feedback_service.py",
"services/test_file_service.py",
"services/test_hit_testing_service.py",
"services/test_human_input_delivery_test.py",
"services/test_human_input_delivery_test_service.py",
"services/test_message_export_service.py",

View File

@ -1,26 +1,79 @@
from __future__ import annotations
import json
from typing import Any, cast
from unittest.mock import ANY, MagicMock, patch
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from pydantic import BaseModel, ConfigDict, TypeAdapter
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.models.document import Document
from models.dataset import Dataset, DatasetQuery
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import DataSourceType, DocumentCreatedFrom, SegmentStatus
from services.hit_testing_service import HitTestingService
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
tenant_id = str(uuid4())
created_by = str(uuid4())
class _QueryResponse(BaseModel):
    """Typed view of the ``query`` entry of a hit-testing response used by the assertions in this module."""

    # Query text; required, so validation fails when the key is absent.
    content: str
class _RetrieveRecordResponse(BaseModel):
    """Typed view of one entry in a response's ``records`` list.

    Only the two keys these tests inspect are declared; both are optional.
    """

    # extra="allow": undeclared keys are kept on the instance rather than dropped.
    model_config = ConfigDict(extra="allow")

    content: str | None = None
    title: str | None = None
class _RetrieveResponse(BaseModel):
    """Typed view of a whole hit-testing response: the echoed query plus the returned records."""

    # Both keys are required; an empty ``records`` list is still valid.
    query: _QueryResponse
    records: list[_RetrieveRecordResponse]
class _DumpedDocumentResponse(BaseModel):
    """Typed view of the document dict attached to a dumped segment."""

    id: str
    data_source_type: str
    name: str
    # Nullable but declared without defaults, so both keys must be present in the payload.
    doc_type: str | None
    doc_metadata: dict[str, object] | None
class _DumpedSegmentResponse(BaseModel):
    """Typed view of the ``segment`` entry of a dumped retrieval record.

    Declares only the fields asserted on in this module.
    """

    # extra="allow": undeclared segment keys are kept on the instance rather than dropped.
    model_config = ConfigDict(extra="allow")

    id: str
    document_id: str
    created_at: datetime | None = None
    # Optional: stays None when the payload carries no ``document`` key.
    document: _DumpedDocumentResponse | None = None
class _DumpedRetrievalRecordResponse(BaseModel):
    """Typed view of one dumped retrieval record: a segment and its score."""

    # extra="allow": undeclared record keys are kept on the instance rather than dropped.
    model_config = ConfigDict(extra="allow")

    segment: _DumpedSegmentResponse
    score: float


# Module-level adapter, built once, that validates a whole list of dumped records.
_DUMPED_RETRIEVAL_RECORDS = TypeAdapter(list[_DumpedRetrievalRecordResponse])
def _create_dataset(
db_session: Session,
*,
provider: str = "vendor",
tenant_id: str | None = None,
created_by: str | None = None,
name: str = "test-dataset",
) -> Dataset:
ds = Dataset(
tenant_id=kwargs.get("tenant_id", tenant_id),
name=kwargs.get("name", "test-dataset"),
created_by=kwargs.get("created_by", created_by),
tenant_id=tenant_id or str(uuid4()),
name=name,
created_by=created_by or str(uuid4()),
provider=provider,
)
db_session.add(ds)
@ -29,36 +82,106 @@ def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs:
return ds
def _create_dataset_document(
    db_session: Session,
    *,
    name: str = "guide.md",
    data_source_type: str = DataSourceType.UPLOAD_FILE,
    doc_type: str | None = None,
    doc_metadata: dict[str, object] | None = None,
) -> DatasetDocument:
    """Persist a new dataset with a single document and return that document.

    Args:
        db_session: Session used to add, flush, commit and refresh the rows.
        name: Name stored on the document.
        data_source_type: Data source type stored on the document (the parent
            dataset always uses ``DataSourceType.UPLOAD_FILE``).
        doc_type: Optional document type stored on the document.
        doc_metadata: Optional metadata stored on the document.

    Returns:
        The committed document, refreshed from the session.
    """
    # One random tenant and creator are shared by the dataset and its document.
    owner_tenant_id = str(uuid4())
    author_id = str(uuid4())

    parent_dataset = Dataset(
        tenant_id=owner_tenant_id,
        name=f"dataset-{uuid4()}",
        data_source_type=DataSourceType.UPLOAD_FILE,
        created_by=author_id,
    )
    db_session.add(parent_dataset)
    # Flushed before ``parent_dataset.id`` is read below — presumably so the id is populated.
    db_session.flush()

    new_document = DatasetDocument(
        tenant_id=owner_tenant_id,
        dataset_id=parent_dataset.id,
        position=1,
        data_source_type=data_source_type,
        batch=f"batch-{uuid4()}",
        name=name,
        created_from=DocumentCreatedFrom.WEB,
        created_by=author_id,
        doc_type=doc_type,
        doc_metadata=doc_metadata,
    )
    db_session.add(new_document)
    db_session.commit()
    db_session.refresh(new_document)
    return new_document
def _build_segment(
    *,
    document_id: str,
    tenant_id: str | None = None,
    dataset_id: str | None = None,
    created_by: str | None = None,
) -> DocumentSegment:
    """Build a completed ``DocumentSegment`` without adding it to any session.

    Args:
        document_id: Stored on the segment as given (an empty string is kept).
        tenant_id: Tenant id; a random UUID string is used when falsy.
        dataset_id: Dataset id; a random UUID string is used when falsy.
        created_by: Creator id; a random UUID string is used when falsy.

    Returns:
        The new, unsaved segment with fixed content and counters.
    """
    # ``or`` rather than ``is None``: an empty string also falls back to a fresh UUID.
    resolved_tenant = tenant_id or str(uuid4())
    resolved_dataset = dataset_id or str(uuid4())
    resolved_creator = created_by or str(uuid4())

    return DocumentSegment(
        tenant_id=resolved_tenant,
        dataset_id=resolved_dataset,
        document_id=document_id,
        created_by=resolved_creator,
        position=1,
        content="segment content",
        word_count=2,
        tokens=2,
        status=SegmentStatus.COMPLETED,
    )
def _create_segment(db_session: Session, *, document: DatasetDocument | None = None) -> DocumentSegment:
    """Persist a segment, optionally tied to an existing document, and return it.

    Args:
        db_session: Session used to add, commit and refresh the segment.
        document: When given, the segment copies its tenant, dataset, id and
            creator. Otherwise the segment gets a random ``document_id`` and
            random owner ids from ``_build_segment``.

    Returns:
        The committed segment, refreshed from the session.
    """
    if document:
        new_segment = _build_segment(
            tenant_id=document.tenant_id,
            dataset_id=document.dataset_id,
            document_id=document.id,
            created_by=document.created_by,
        )
    else:
        # No document: only the document id is supplied, as a random UUID string.
        new_segment = _build_segment(document_id=str(uuid4()))

    db_session.add(new_segment)
    db_session.commit()
    db_session.refresh(new_segment)
    return new_segment
class TestHitTestingService:
# ── Utility methods (pure logic, no DB) ────────────────────────────
def test_escape_query_for_search_should_escape_double_quotes(self):
def test_escape_query_for_search_should_escape_double_quotes(self) -> None:
query = 'test "query" with quotes'
result = HitTestingService.escape_query_for_search(query)
assert result == 'test \\"query\\" with quotes'
def test_hit_testing_args_check_should_pass_with_valid_query(self):
def test_hit_testing_args_check_should_pass_with_valid_query(self) -> None:
HitTestingService.hit_testing_args_check({"query": "valid query"})
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
def test_hit_testing_args_check_should_pass_with_valid_attachments(self) -> None:
HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self) -> None:
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
HitTestingService.hit_testing_args_check({})
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self) -> None:
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
HitTestingService.hit_testing_args_check({"query": "a" * 251})
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self) -> None:
with pytest.raises(ValueError, match="Attachment_ids must be a list"):
HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
# ── Response formatting ────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
def test_compact_retrieve_response_should_format_correctly(self, mock_format: MagicMock) -> None:
query = "test query"
mock_doc = MagicMock(spec=Document)
@ -66,50 +189,49 @@ class TestHitTestingService:
mock_record.model_dump.return_value = {"content": "formatted content"}
mock_format.return_value = [mock_record]
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
response = _RetrieveResponse.model_validate(HitTestingService.compact_retrieve_response(query, [mock_doc]))
assert cast(dict[str, Any], result["query"])["content"] == query
assert len(result["records"]) == 1
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
assert response.query.content == query
assert len(response.records) == 1
assert response.records[0].content == "formatted content"
mock_format.assert_called_once_with([mock_doc])
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
self, db_session_with_containers: Session
):
) -> None:
dataset = _create_dataset(db_session_with_containers, provider="external")
documents = [
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
]
result = cast(
dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
response = _RetrieveResponse.model_validate(
HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
)
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert len(result["records"]) == 2
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
assert response.query.content == "test query"
assert len(response.records) == 2
assert response.records[0].content == "c1"
assert response.records[1].title == "t2"
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
self, db_session_with_containers: Session
):
) -> None:
dataset = _create_dataset(db_session_with_containers, provider="vendor")
result = cast(
dict[str, Any],
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
response = _RetrieveResponse.model_validate(
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}])
)
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert result["records"] == []
assert response.query.content == "test query"
assert response.records == []
# ── External retrieve (real DB) ────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
def test_external_retrieve_should_succeed_for_external_provider(
self, mock_ext_retrieve, db_session_with_containers: Session
):
self, mock_ext_retrieve: MagicMock, db_session_with_containers: Session
) -> None:
dataset = _create_dataset(db_session_with_containers, provider="external")
account_id = str(uuid4())
account = MagicMock()
@ -118,19 +240,18 @@ class TestHitTestingService:
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast(
dict[str, Any],
response = _RetrieveResponse.model_validate(
HitTestingService.external_retrieve(
dataset=dataset,
query='test "query"',
account=account,
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
),
)
)
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
assert response.query.content == 'test "query"'
assert response.records[0].content == "ext content"
mock_ext_retrieve.assert_called_once_with(
dataset_id=dataset.id,
query='test \\"query\\"',
@ -142,37 +263,44 @@ class TestHitTestingService:
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
assert after_count == before_count + 1
def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
def test_external_retrieve_should_return_empty_for_non_external_provider(
self, db_session_with_containers: Session
) -> None:
dataset = _create_dataset(db_session_with_containers, provider="vendor")
account = MagicMock()
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
response = _RetrieveResponse.model_validate(HitTestingService.external_retrieve(dataset, "test query", account))
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert result["records"] == []
assert response.query.content == "test query"
assert response.records == []
# ── Retrieve (real DB) ─────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
def test_retrieve_should_use_default_model_when_none_provided(
self, mock_retrieve, db_session_with_containers: Session
):
self, mock_retrieve: MagicMock, db_session_with_containers: Session
) -> None:
dataset = _create_dataset(db_session_with_containers)
dataset.retrieval_model = None
account = MagicMock()
account.id = str(uuid4())
mock_retrieve.return_value = []
retrieved_documents: list[Document] = []
mock_retrieve.return_value = retrieved_documents
external_retrieval_model: dict[str, object] = {}
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast(
dict[str, Any],
response = _RetrieveResponse.model_validate(
HitTestingService.retrieve(
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
),
dataset=dataset,
query="test query",
account=account,
retrieval_model=None,
external_retrieval_model=external_retrieval_model,
)
)
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert response.query.content == "test query"
mock_retrieve.assert_called_once()
assert mock_retrieve.call_args.kwargs["top_k"] == 4
@ -183,11 +311,12 @@ class TestHitTestingService:
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
def test_retrieve_should_handle_metadata_filtering(
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
):
self, mock_get_meta: MagicMock, mock_retrieve: MagicMock, db_session_with_containers: Session
) -> None:
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = str(uuid4())
external_retrieval_model: dict[str, object] = {}
retrieval_model = {
"search_method": "semantic_search",
@ -197,14 +326,15 @@ class TestHitTestingService:
"score_threshold_enabled": False,
}
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
mock_retrieve.return_value = []
retrieved_documents: list[Document] = []
mock_retrieve.return_value = retrieved_documents
HitTestingService.retrieve(
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
external_retrieval_model=external_retrieval_model,
)
mock_get_meta.assert_called_once()
@ -214,10 +344,11 @@ class TestHitTestingService:
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
):
self, mock_get_meta: MagicMock, mock_retrieve: MagicMock, db_session_with_containers: Session
) -> None:
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
external_retrieval_model: dict[str, object] = {}
retrieval_model = {
"search_method": "semantic_search",
@ -226,28 +357,31 @@ class TestHitTestingService:
"reranking_enable": False,
"score_threshold_enabled": False,
}
mock_get_meta.return_value = ({}, "condition_string")
empty_document_ids: dict[str, list[str]] = {}
mock_get_meta.return_value = (empty_document_ids, "condition_string")
result = cast(
dict[str, Any],
response = _RetrieveResponse.model_validate(
HitTestingService.retrieve(
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
),
external_retrieval_model=external_retrieval_model,
)
)
assert result["records"] == []
assert response.records == []
mock_retrieve.assert_not_called()
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
def test_retrieve_should_handle_attachments(
self, mock_retrieve: MagicMock, db_session_with_containers: Session
) -> None:
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = str(uuid4())
attachment_ids = ["att1", "att2"]
external_retrieval_model: dict[str, object] = {}
retrieval_model = {
"search_method": "semantic_search",
@ -255,19 +389,20 @@ class TestHitTestingService:
"reranking_enable": False,
"score_threshold_enabled": False,
}
mock_retrieve.return_value = []
retrieved_documents: list[Document] = []
mock_retrieve.return_value = retrieved_documents
HitTestingService.retrieve(
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
external_retrieval_model=external_retrieval_model,
attachment_ids=attachment_ids,
)
mock_retrieve.assert_called_once_with(
retrieval_method=ANY,
retrieval_method=RetrievalMethod.SEMANTIC_SEARCH,
dataset_id=dataset.id,
query="test query",
attachment_ids=attachment_ids,
@ -295,10 +430,13 @@ class TestHitTestingService:
assert query_content[1]["content"] == "att1"
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
def test_retrieve_should_handle_reranking_and_threshold(
self, mock_retrieve: MagicMock, db_session_with_containers: Session
) -> None:
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = str(uuid4())
external_retrieval_model: dict[str, object] = {}
retrieval_model = {
"search_method": "hybrid_search",
@ -310,14 +448,15 @@ class TestHitTestingService:
"score_threshold": 0.5,
"weights": {"vector": 0.5, "keyword": 0.5},
}
mock_retrieve.return_value = []
retrieved_documents: list[Document] = []
mock_retrieve.return_value = retrieved_documents
HitTestingService.retrieve(
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
external_retrieval_model=external_retrieval_model,
)
mock_retrieve.assert_called_once()
@ -326,3 +465,57 @@ class TestHitTestingService:
assert kwargs["reranking_model"] == {"provider": "test"}
assert kwargs["reranking_mode"] == "weighted_sum"
assert kwargs["weights"] == {"vector": 0.5, "keyword": 0.5}
def test_dump_dataset_document_returns_frontend_required_fields(self, db_session_with_containers: Session) -> None:
    """Dumping a persisted document yields exactly the five expected keys and values."""
    document = _create_dataset_document(db_session_with_containers, doc_metadata={"source": "manual"})

    dumped = HitTestingService._dump_dataset_document(document)

    expected = {
        "id": document.id,
        "data_source_type": "upload_file",
        "name": "guide.md",
        "doc_type": None,
        "doc_metadata": {"source": "manual"},
    }
    assert dumped == expected
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self) -> None:
    """A record whose segment has an empty ``document_id`` is still dumped, with id and score intact."""
    # Unsaved segment with an empty document id; no session is involved in this test.
    segment = _build_segment(document_id="")
    retrieval_record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95})

    raw_records = HitTestingService._dump_retrieval_records([retrieval_record])
    dumped_records = _DUMPED_RETRIEVAL_RECORDS.validate_python(raw_records)

    assert len(dumped_records) == 1
    only_record = dumped_records[0]
    assert only_record.segment.id == segment.id
    assert only_record.segment.document_id == ""
    assert only_record.score == 0.95
def test_dump_retrieval_records_injects_documents(self, db_session_with_containers: Session) -> None:
    """A record whose segment points at a persisted document is dumped with that document attached."""
    document = _create_dataset_document(db_session_with_containers)
    segment = _create_segment(db_session_with_containers, document=document)
    retrieval_record = RetrievalSegments.model_validate({"segment": segment, "score": 0.9})

    raw_records = HitTestingService._dump_retrieval_records([retrieval_record])
    dumped_records = _DUMPED_RETRIEVAL_RECORDS.validate_python(raw_records)

    assert len(dumped_records) == 1
    only_record = dumped_records[0]
    dumped_segment = only_record.segment
    assert dumped_segment.id == segment.id
    assert dumped_segment.document_id == document.id
    assert dumped_segment.created_at == segment.created_at

    # The helper's defaults apply: "guide.md", no doc_type, no doc_metadata.
    expected_document = _DumpedDocumentResponse(
        id=document.id,
        data_source_type="upload_file",
        name="guide.md",
        doc_type=None,
        doc_metadata=None,
    )
    assert dumped_segment.document == expected_document
    assert only_record.score == 0.9
def test_dump_retrieval_records_skips_records_with_missing_documents(
    self, db_session_with_containers: Session, caplog: pytest.LogCaptureFixture
) -> None:
    """A record whose segment has no matching document is dropped and a message is logged."""
    # No document is passed, so the segment carries a random document_id.
    orphan_segment = _create_segment(db_session_with_containers)
    retrieval_record = RetrievalSegments.model_validate({"segment": orphan_segment, "score": 0.95})

    dumped = HitTestingService._dump_retrieval_records([retrieval_record])

    assert dumped == []
    assert "Skipping hit-testing records with missing documents" in caplog.text

View File

@ -1,102 +0,0 @@
from datetime import datetime
from unittest.mock import Mock, patch
import pytest
from services.hit_testing_service import HitTestingService
def _retrieval_record(payload: dict):
    """Build a ``Mock`` retrieval record whose ``model_dump()`` returns ``payload``.

    When ``payload["segment"]`` is a dict, ``record.segment`` is a ``Mock``
    carrying that dict's ``id`` and ``document_id`` (``None`` when a key is
    absent) and a fixed ``created_at`` of 2024-01-01 00:00:00. Otherwise
    ``record.segment`` is ``None``.
    """
    record = Mock()
    record.model_dump.return_value = payload

    segment_payload = payload.get("segment")
    if not isinstance(segment_payload, dict):
        record.segment = None
        return record

    # Keyword arguments to Mock() become attributes on the created mock.
    record.segment = Mock(
        id=segment_payload.get("id"),
        document_id=segment_payload.get("document_id"),
        created_at=datetime(2024, 1, 1, 0, 0, 0),
    )
    return record
def _dataset_document(
document_id: str = "document-1",
name: str = "guide.md",
data_source_type: str = "upload_file",
doc_type: str | None = None,
doc_metadata: dict | None = None,
):
document = Mock()
document.id = document_id
document.name = name
document.data_source_type = data_source_type
document.doc_type = doc_type
document.doc_metadata = doc_metadata
return document
class TestHitTestingServiceDumpRecords:
    """Mock-based tests for ``HitTestingService._dump_dataset_document`` and ``_dump_retrieval_records``."""

    def test_dump_dataset_document_returns_frontend_required_fields(self):
        """The dumped document dict has exactly the five expected keys and values."""
        mock_document = _dataset_document(doc_metadata={"source": "manual"})

        dumped = HitTestingService._dump_dataset_document(mock_document)

        assert dumped == {
            "id": "document-1",
            "data_source_type": "upload_file",
            "name": "guide.md",
            "doc_type": None,
            "doc_metadata": {"source": "manual"},
        }

    def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self):
        """A record whose segment has ``document_id=None`` is returned as dumped."""
        record = _retrieval_record({"segment": {"id": "segment-1", "document_id": None}, "score": 0.95})
        record.segment.document_id = None

        dumped = HitTestingService._dump_retrieval_records([record])

        assert dumped == [{"segment": {"id": "segment-1", "document_id": None}, "score": 0.95}]

    def test_dump_retrieval_records_injects_documents(self):
        """With ``db.session.scalars`` patched to yield one document, it is attached to the dumped segment."""
        record = _retrieval_record(
            {
                "segment": {
                    "id": "segment-1",
                    "document_id": "document-1",
                },
                "score": 0.9,
            }
        )
        # Stand-in for the scalars() result: ``.all()`` yields a single mock document.
        query_result = Mock()
        query_result.all.return_value = [_dataset_document()]

        with patch("services.hit_testing_service.db.session.scalars", return_value=query_result):
            dumped = HitTestingService._dump_retrieval_records([record])

        dumped_segment = dumped[0]["segment"]
        assert dumped_segment["document"] == {
            "id": "document-1",
            "data_source_type": "upload_file",
            "name": "guide.md",
            "doc_type": None,
            "doc_metadata": None,
        }
        assert dumped_segment["created_at"] == datetime(2024, 1, 1, 0, 0, 0)

    def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog: pytest.LogCaptureFixture):
        """With ``db.session.scalars`` patched to yield no documents, the record is dropped and a message is logged."""
        record = _retrieval_record(
            {
                "segment": {
                    "id": "segment-1",
                    "document_id": "missing-document",
                },
                "score": 0.95,
            }
        )
        # Stand-in for the scalars() result: ``.all()`` yields nothing.
        empty_result = Mock()
        empty_result.all.return_value = []

        with patch("services.hit_testing_service.db.session.scalars", return_value=empty_result):
            dumped = HitTestingService._dump_retrieval_records([record])

        assert dumped == []
        assert "Skipping hit-testing records with missing documents" in caplog.text