test: migrate hit_testing_service tests to testcontainers (#34750)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
volcano303 2026-04-10 10:26:40 +02:00 committed by GitHub
parent 28b8215c9b
commit e224c77920
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,239 +1,193 @@
from __future__ import annotations
import json
from typing import Any, cast
from unittest.mock import ANY, MagicMock, patch
from uuid import uuid4
import pytest
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.rag.models.document import Document
from models.dataset import Dataset
from models.dataset import Dataset, DatasetQuery
from services.hit_testing_service import HitTestingService
class TestHitTestingService:
"""Test suite for HitTestingService"""
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
tenant_id = str(uuid4())
created_by = str(uuid4())
ds = Dataset(
tenant_id=kwargs.get("tenant_id", tenant_id),
name=kwargs.get("name", "test-dataset"),
created_by=kwargs.get("created_by", created_by),
provider=provider,
)
db_session.add(ds)
db_session.commit()
db_session.refresh(ds)
return ds
# ===== Utility Method Tests =====
class TestHitTestingService:
# ── Utility methods (pure logic, no DB) ────────────────────────────
def test_escape_query_for_search_should_escape_double_quotes(self):
"""Test that escape_query_for_search escapes double quotes correctly"""
# Arrange
query = 'test "query" with quotes'
expected = 'test \\"query\\" with quotes'
# Act
result = HitTestingService.escape_query_for_search(query)
# Assert
assert result == expected
assert result == 'test \\"query\\" with quotes'
def test_hit_testing_args_check_should_pass_with_valid_query(self):
"""Test that hit_testing_args_check passes with a valid query"""
# Arrange
args = {"query": "valid query"}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
HitTestingService.hit_testing_args_check({"query": "valid query"})
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
"""Test that hit_testing_args_check passes with valid attachment_ids"""
# Arrange
args = {"attachment_ids": ["id1", "id2"]}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
# Arrange
args = {}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query or attachment_ids is required" in str(exc_info.value)
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
HitTestingService.hit_testing_args_check({})
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
# Arrange
args = {"query": "a" * 251}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query cannot exceed 250 characters" in str(exc_info.value)
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
HitTestingService.hit_testing_args_check({"query": "a" * 251})
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
# Arrange
args = {"attachment_ids": "not a list"}
with pytest.raises(ValueError, match="Attachment_ids must be a list"):
HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Attachment_ids must be a list" in str(exc_info.value)
# ===== Response Formatting Tests =====
# ── Response formatting ────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
"""Test that compact_retrieve_response formats the response correctly"""
# Arrange
query = "test query"
mock_doc = MagicMock(spec=Document)
documents = [mock_doc]
mock_record = MagicMock()
mock_record.model_dump.return_value = {"content": "formatted content"}
mock_format.return_value = [mock_record]
# Act
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert len(result["records"]) == 1
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
mock_format.assert_called_once_with(documents)
mock_format.assert_called_once_with([mock_doc])
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "external"
query = "test query"
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
self, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="external")
documents = [
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
]
# Act
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
result = cast(
dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert len(result["records"]) == 2
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
documents = [{"content": "c1"}]
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
self, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="vendor")
# Act
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
result = cast(
dict[str, Any],
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert result["records"] == []
# ===== External Retrieve Tests =====
# ── External retrieve (real DB) ────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
dataset.provider = "external"
query = 'test "query"'
def test_external_retrieve_should_succeed_for_external_provider(
self, mock_ext_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers, provider="external")
account_id = str(uuid4())
account = MagicMock()
account.id = "account_id"
account.id = account_id
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
# Act
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast(
dict[str, Any],
HitTestingService.external_retrieve(
dataset=dataset,
query=query,
query='test "query"',
account=account,
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
# Verify call to RetrievalService.external_retrieve with escaped query
mock_ext_retrieve.assert_called_once_with(
dataset_id="dataset_id",
dataset_id=dataset.id,
query='test \\"query\\"',
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
)
# Verify DatasetQuery record was added and committed
mock_add.assert_called_once()
mock_commit.assert_called_once()
db_session_with_containers.expire_all()
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
assert after_count == before_count + 1
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
"""Test that external_retrieve returns empty results immediately if provider is not external"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers, provider="vendor")
account = MagicMock()
# Act
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
assert result["records"] == []
# ===== Retrieve Tests =====
# ── Retrieve (real DB) ─────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve uses default model when retrieval_model is not provided"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
def test_retrieve_should_use_default_model_when_none_provided(
self, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
dataset.retrieval_model = None
query = "test query"
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
mock_retrieve.return_value = []
# Act
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
result = cast(
dict[str, Any],
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["query"])["content"] == "test query"
mock_retrieve.assert_called_once()
# Verify top_k from default_retrieval_model (4)
assert mock_retrieve.call_args.kwargs["top_k"] == 4
mock_commit.assert_called_once()
db_session_with_containers.expire_all()
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
assert after_count == before_count + 1
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_metadata_filtering(
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
retrieval_model = {
"search_method": "semantic_search",
@ -242,29 +196,27 @@ class TestHitTestingService:
"reranking_enable": False,
"score_threshold_enabled": False,
}
# Mock metadata filtering response
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
)
# Assert
mock_get_meta.assert_called_once()
mock_retrieve.assert_called_once()
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
retrieval_model = {
@ -274,37 +226,27 @@ class TestHitTestingService:
"reranking_enable": False,
"score_threshold_enabled": False,
}
# Mock metadata filtering response: condition returned but no IDs
mock_get_meta.return_value = ({}, "condition_string")
# Act
result = cast(
dict[str, Any],
HitTestingService.retrieve(
dataset=dataset,
query=query,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
),
)
# Assert
assert result["records"] == []
mock_retrieve.assert_not_called()
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
attachment_ids = ["att1", "att2"]
retrieval_model = {
@ -315,21 +257,19 @@ class TestHitTestingService:
}
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset,
query=query,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
attachment_ids=attachment_ids,
)
# Assert
mock_retrieve.assert_called_once_with(
retrieval_method=ANY,
dataset_id="dataset_id",
query=query,
dataset_id=dataset.id,
query="test query",
attachment_ids=attachment_ids,
top_k=4,
score_threshold=0.0,
@ -338,26 +278,27 @@ class TestHitTestingService:
weights=None,
document_ids_filter=None,
)
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
called_query = mock_add.call_args[0][0]
query_content = json.loads(called_query.content)
# Verify DatasetQuery was persisted with correct content structure
db_session_with_containers.expire_all()
latest = db_session_with_containers.scalar(
select(DatasetQuery)
.where(DatasetQuery.dataset_id == dataset.id)
.order_by(DatasetQuery.created_at.desc())
.limit(1)
)
assert latest is not None
query_content = json.loads(latest.content)
assert len(query_content) == 3 # 1 text + 2 images
assert query_content[0]["content_type"] == "text_query"
assert query_content[1]["content_type"] == "image_query"
assert query_content[1]["content"] == "att1"
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve passes reranking and threshold parameters correctly"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
dataset = _create_dataset(db_session_with_containers)
account = MagicMock()
account.id = "account_id"
account.id = str(uuid4())
retrieval_model = {
"search_method": "hybrid_search",
@ -371,12 +312,14 @@ class TestHitTestingService:
}
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
dataset=dataset,
query="test query",
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
)
# Assert
mock_retrieve.assert_called_once()
kwargs = mock_retrieve.call_args.kwargs
assert kwargs["score_threshold"] == 0.5