diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py similarity index 51% rename from api/tests/unit_tests/services/test_hit_testing_service.py rename to api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index 80e9729f5b..f332ba05ec 100644 --- a/api/tests/unit_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -1,239 +1,193 @@ +from __future__ import annotations + import json from typing import Any, cast from unittest.mock import ANY, MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.models.document import Document -from models.dataset import Dataset +from models.dataset import Dataset, DatasetQuery from services.hit_testing_service import HitTestingService -class TestHitTestingService: - """Test suite for HitTestingService""" +def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset: + tenant_id = str(uuid4()) + created_by = str(uuid4()) + ds = Dataset( + tenant_id=kwargs.get("tenant_id", tenant_id), + name=kwargs.get("name", "test-dataset"), + created_by=kwargs.get("created_by", created_by), + provider=provider, + ) + db_session.add(ds) + db_session.commit() + db_session.refresh(ds) + return ds - # ===== Utility Method Tests ===== + +class TestHitTestingService: + # ── Utility methods (pure logic, no DB) ──────────────────────────── def test_escape_query_for_search_should_escape_double_quotes(self): - """Test that escape_query_for_search escapes double quotes correctly""" - # Arrange query = 'test "query" with quotes' - expected = 'test \\"query\\" with quotes' - - # Act result = HitTestingService.escape_query_for_search(query) - - # Assert - assert result == expected + assert result == 'test \\"query\\" with quotes' def test_hit_testing_args_check_should_pass_with_valid_query(self): - """Test that hit_testing_args_check passes with a valid query""" - # Arrange - args = {"query": "valid query"} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"query": "valid query"}) def test_hit_testing_args_check_should_pass_with_valid_attachments(self): - """Test that hit_testing_args_check passes with valid attachment_ids""" - # Arrange - args = {"attachment_ids": ["id1", "id2"]} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]}) def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): - """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" - # Arrange - args = {} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query or attachment_ids is required" in str(exc_info.value) + with pytest.raises(ValueError, match="Query or attachment_ids is required"): + HitTestingService.hit_testing_args_check({}) def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): - """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" - # Arrange - args = {"query": "a" * 251} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query cannot exceed 250 characters" in str(exc_info.value) + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): + HitTestingService.hit_testing_args_check({"query": "a" * 251}) def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): - """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" - # Arrange - args = {"attachment_ids": "not a list"} + with pytest.raises(ValueError, match="Attachment_ids must be a list"): + HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Attachment_ids must be a list" in str(exc_info.value) - - # ===== Response Formatting Tests ===== + # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") def test_compact_retrieve_response_should_format_correctly(self, mock_format): - """Test that compact_retrieve_response formats the response correctly""" - # Arrange query = "test query" mock_doc = MagicMock(spec=Document) - documents = [mock_doc] mock_record = MagicMock() mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - # Act - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) - # Assert assert cast(dict[str, Any], result["query"])["content"] == query assert len(result["records"]) == 1 assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" - mock_format.assert_called_once_with(documents) + mock_format.assert_called_once_with([mock_doc]) - def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): - """Test that compact_external_retrieve_response returns records when dataset provider is external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "external" - query = "test query" + def test_compact_external_retrieve_response_should_return_records_for_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert len(result["records"]) == 2 assert cast(dict[str, Any], result["records"][0])["content"] == "c1" assert cast(dict[str, Any], result["records"][1])["title"] == "t2" - def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): - """Test that compact_external_retrieve_response returns empty records for non-external provider""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" - documents = [{"content": "c1"}] + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="vendor") - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== External Retrieve Tests ===== + # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): - """Test that external_retrieve successfully retrieves from external provider and commits query""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - dataset.provider = "external" - query = 'test "query"' + def test_external_retrieve_should_succeed_for_external_provider( + self, mock_ext_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") + account_id = str(uuid4()) account = MagicMock() - account.id = "account_id" - + account.id = account_id mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.external_retrieve( dataset=dataset, - query=query, + query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" - - # Verify call to RetrievalService.external_retrieve with escaped query mock_ext_retrieve.assert_called_once_with( - dataset_id="dataset_id", + dataset_id=dataset.id, query='test \\"query\\"', external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ) - # Verify DatasetQuery record was added and committed - mock_add.assert_called_once() - mock_commit.assert_called_once() + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self): - """Test that external_retrieve returns empty results immediately if provider is not external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" + def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - # Act - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== Retrieve Tests ===== + # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve uses default model when retrieval_model is not provided""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" + def test_retrieve_should_use_default_model_when_none_provided( + self, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None - query = "test query" account = MagicMock() - account.id = "account_id" - + account.id = str(uuid4()) mock_retrieve.return_value = [] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" mock_retrieve.assert_called_once() - # Verify top_k from default_retrieval_model (4) assert mock_retrieve.call_args.kwargs["top_k"] == 4 - mock_commit.assert_called_once() + + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): - """Test that retrieve correctly calls metadata filtering when conditions are present""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_metadata_filtering( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "semantic_search", @@ -242,29 +196,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response - mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_get_meta.assert_called_once() mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): - """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_return_empty_if_metadata_filtering_fails( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() retrieval_model = { @@ -274,37 +226,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response: condition returned but no IDs mock_get_meta.return_value = ({}, "condition_string") - # Act result = cast( dict[str, Any], HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, ), ) - # Assert assert result["records"] == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) attachment_ids = ["att1", "att2"] retrieval_model = { @@ -315,21 +257,19 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, attachment_ids=attachment_ids, ) - # Assert mock_retrieve.assert_called_once_with( retrieval_method=ANY, - dataset_id="dataset_id", - query=query, + dataset_id=dataset.id, + query="test query", attachment_ids=attachment_ids, top_k=4, score_threshold=0.0, @@ -338,26 +278,27 @@ class TestHitTestingService: weights=None, document_ids_filter=None, ) - # Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images) - # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) - called_query = mock_add.call_args[0][0] - query_content = json.loads(called_query.content) + + # Verify DatasetQuery was persisted with correct content structure + db_session_with_containers.expire_all() + latest = db_session_with_containers.scalar( + select(DatasetQuery) + .where(DatasetQuery.dataset_id == dataset.id) + .order_by(DatasetQuery.created_at.desc()) + .limit(1) + ) + assert latest is not None + query_content = json.loads(latest.content) assert len(query_content) == 3 # 1 text + 2 images assert query_content[0]["content_type"] == "text_query" assert query_content[1]["content_type"] == "image_query" assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve passes reranking and threshold parameters correctly""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "hybrid_search", @@ -371,12 +312,14 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_retrieve.assert_called_once() kwargs = mock_retrieve.call_args.kwargs assert kwargs["score_threshold"] == 0.5