diff --git a/api/tests/test_containers_integration_tests/.ruff.toml b/api/tests/test_containers_integration_tests/.ruff.toml index 250cf103ab9..68e3f9af4bd 100644 --- a/api/tests/test_containers_integration_tests/.ruff.toml +++ b/api/tests/test_containers_integration_tests/.ruff.toml @@ -9,7 +9,6 @@ extend-select = ["ANN401", "ARG", "TID251"] "models/test_types_enum_text.py" = ["ANN401", "TID251"] "services/test_app_dsl_service.py" = ["ANN401", "TID251", "ARG"] "services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251", "ARG"] -"services/test_hit_testing_service.py" = ["ANN401", "TID251"] "trigger/conftest.py" = ["ANN401", "TID251"] "trigger/test_trigger_e2e.py" = ["ANN401", "TID251", "ARG"] "controllers/console/app/test_app_apis.py" = ["ARG"] diff --git a/api/tests/test_containers_integration_tests/pyrefly.toml b/api/tests/test_containers_integration_tests/pyrefly.toml index 06ea10036f5..92c84320d9a 100644 --- a/api/tests/test_containers_integration_tests/pyrefly.toml +++ b/api/tests/test_containers_integration_tests/pyrefly.toml @@ -100,7 +100,6 @@ project-excludes = [ "services/test_feature_service.py", "services/test_feedback_service.py", "services/test_file_service.py", - "services/test_hit_testing_service.py", "services/test_human_input_delivery_test.py", "services/test_human_input_delivery_test_service.py", "services/test_message_export_service.py", diff --git a/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index f332ba05ec2..2d23ae8f68f 100644 --- a/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -1,26 +1,79 @@ from __future__ import annotations import json -from typing import Any, cast -from unittest.mock import ANY, MagicMock, patch +from datetime import datetime +from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from pydantic import BaseModel, ConfigDict, TypeAdapter from sqlalchemy import func, select from sqlalchemy.orm import Session +from core.rag.embedding.retrieval import RetrievalSegments from core.rag.models.document import Document -from models.dataset import Dataset, DatasetQuery +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from models.dataset import Dataset, DatasetQuery, DocumentSegment +from models.dataset import Document as DatasetDocument +from models.enums import DataSourceType, DocumentCreatedFrom, SegmentStatus from services.hit_testing_service import HitTestingService -def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset: - tenant_id = str(uuid4()) - created_by = str(uuid4()) +class _QueryResponse(BaseModel): + content: str + + +class _RetrieveRecordResponse(BaseModel): + content: str | None = None + title: str | None = None + + model_config = ConfigDict(extra="allow") + + +class _RetrieveResponse(BaseModel): + query: _QueryResponse + records: list[_RetrieveRecordResponse] + + +class _DumpedDocumentResponse(BaseModel): + id: str + data_source_type: str + name: str + doc_type: str | None + doc_metadata: dict[str, object] | None + + +class _DumpedSegmentResponse(BaseModel): + id: str + document_id: str + created_at: datetime | None = None + document: _DumpedDocumentResponse | None = None + + model_config = ConfigDict(extra="allow") + + +class _DumpedRetrievalRecordResponse(BaseModel): + segment: _DumpedSegmentResponse + score: float + + model_config = ConfigDict(extra="allow") + + +_DUMPED_RETRIEVAL_RECORDS = TypeAdapter(list[_DumpedRetrievalRecordResponse]) + + +def _create_dataset( + db_session: Session, + *, + provider: str = "vendor", + tenant_id: str | None = None, + created_by: str | None = None, + name: str = "test-dataset", +) -> Dataset: ds = Dataset( - tenant_id=kwargs.get("tenant_id", tenant_id), - name=kwargs.get("name", "test-dataset"), - created_by=kwargs.get("created_by", created_by), + tenant_id=tenant_id or str(uuid4()), + name=name, + created_by=created_by or str(uuid4()), provider=provider, ) db_session.add(ds) @@ -29,36 +82,106 @@ def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: return ds +def _create_dataset_document( + db_session: Session, + *, + name: str = "guide.md", + data_source_type: str = DataSourceType.UPLOAD_FILE, + doc_type: str | None = None, + doc_metadata: dict[str, object] | None = None, +) -> DatasetDocument: + tenant_id = str(uuid4()) + created_by = str(uuid4()) + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by, + ) + db_session.add(dataset) + db_session.flush() + + document = DatasetDocument( + tenant_id=tenant_id, + dataset_id=dataset.id, + position=1, + data_source_type=data_source_type, + batch=f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=created_by, + doc_type=doc_type, + doc_metadata=doc_metadata, + ) + db_session.add(document) + db_session.commit() + db_session.refresh(document) + return document + + +def _build_segment( + *, + document_id: str, + tenant_id: str | None = None, + dataset_id: str | None = None, + created_by: str | None = None, +) -> DocumentSegment: + return DocumentSegment( + tenant_id=tenant_id or str(uuid4()), + dataset_id=dataset_id or str(uuid4()), + document_id=document_id, + created_by=created_by or str(uuid4()), + position=1, + content="segment content", + word_count=2, + tokens=2, + status=SegmentStatus.COMPLETED, + ) + + +def _create_segment(db_session: Session, *, document: DatasetDocument | None = None) -> DocumentSegment: + segment = _build_segment( + tenant_id=document.tenant_id if document else None, + dataset_id=document.dataset_id if document else None, + document_id=document.id if document else str(uuid4()), + created_by=document.created_by if document else None, + ) + db_session.add(segment) + db_session.commit() + db_session.refresh(segment) + return segment + + class TestHitTestingService: # ── Utility methods (pure logic, no DB) ──────────────────────────── - def test_escape_query_for_search_should_escape_double_quotes(self): + def test_escape_query_for_search_should_escape_double_quotes(self) -> None: query = 'test "query" with quotes' result = HitTestingService.escape_query_for_search(query) assert result == 'test \\"query\\" with quotes' - def test_hit_testing_args_check_should_pass_with_valid_query(self): + def test_hit_testing_args_check_should_pass_with_valid_query(self) -> None: HitTestingService.hit_testing_args_check({"query": "valid query"}) - def test_hit_testing_args_check_should_pass_with_valid_attachments(self): + def test_hit_testing_args_check_should_pass_with_valid_attachments(self) -> None: HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]}) - def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): + def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self) -> None: with pytest.raises(ValueError, match="Query or attachment_ids is required"): HitTestingService.hit_testing_args_check({}) - def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): + def test_hit_testing_args_check_should_raise_error_when_query_too_long(self) -> None: with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): HitTestingService.hit_testing_args_check({"query": "a" * 251}) - def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): + def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self) -> None: with pytest.raises(ValueError, match="Attachment_ids must be a list"): HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") - def test_compact_retrieve_response_should_format_correctly(self, mock_format): + def test_compact_retrieve_response_should_format_correctly(self, mock_format: MagicMock) -> None: query = "test query" mock_doc = MagicMock(spec=Document) @@ -66,50 +189,49 @@ class TestHitTestingService: mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) + response = _RetrieveResponse.model_validate(HitTestingService.compact_retrieve_response(query, [mock_doc])) - assert cast(dict[str, Any], result["query"])["content"] == query - assert len(result["records"]) == 1 - assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" + assert response.query.content == query + assert len(response.records) == 1 + assert response.records[0].content == "formatted content" mock_format.assert_called_once_with([mock_doc]) def test_compact_external_retrieve_response_should_return_records_for_external_provider( self, db_session_with_containers: Session - ): + ) -> None: dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - result = cast( - dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + response = _RetrieveResponse.model_validate( + HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) ) - assert cast(dict[str, Any], result["query"])["content"] == "test query" - assert len(result["records"]) == 2 - assert cast(dict[str, Any], result["records"][0])["content"] == "c1" - assert cast(dict[str, Any], result["records"][1])["title"] == "t2" + assert response.query.content == "test query" + assert len(response.records) == 2 + assert response.records[0].content == "c1" + assert response.records[1].title == "t2" def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( self, db_session_with_containers: Session - ): + ) -> None: dataset = _create_dataset(db_session_with_containers, provider="vendor") - result = cast( - dict[str, Any], - HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + response = _RetrieveResponse.model_validate( + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]) ) - assert cast(dict[str, Any], result["query"])["content"] == "test query" - assert result["records"] == [] + assert response.query.content == "test query" + assert response.records == [] # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") def test_external_retrieve_should_succeed_for_external_provider( - self, mock_ext_retrieve, db_session_with_containers: Session - ): + self, mock_ext_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers, provider="external") account_id = str(uuid4()) account = MagicMock() @@ -118,19 +240,18 @@ class TestHitTestingService: before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 - result = cast( - dict[str, Any], + response = _RetrieveResponse.model_validate( HitTestingService.external_retrieve( dataset=dataset, query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, - ), + ) ) - assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' - assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" + assert response.query.content == 'test "query"' + assert response.records[0].content == "ext content" mock_ext_retrieve.assert_called_once_with( dataset_id=dataset.id, query='test \\"query\\"', @@ -142,37 +263,44 @@ class TestHitTestingService: after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + def test_external_retrieve_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) + response = _RetrieveResponse.model_validate(HitTestingService.external_retrieve(dataset, "test query", account)) - assert cast(dict[str, Any], result["query"])["content"] == "test query" - assert result["records"] == [] + assert response.query.content == "test query" + assert response.records == [] # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") def test_retrieve_should_use_default_model_when_none_provided( - self, mock_retrieve, db_session_with_containers: Session - ): + self, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None account = MagicMock() account.id = str(uuid4()) - mock_retrieve.return_value = [] + retrieved_documents: list[Document] = [] + mock_retrieve.return_value = retrieved_documents + external_retrieval_model: dict[str, object] = {} before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 - result = cast( - dict[str, Any], + response = _RetrieveResponse.model_validate( HitTestingService.retrieve( - dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} - ), + dataset=dataset, + query="test query", + account=account, + retrieval_model=None, + external_retrieval_model=external_retrieval_model, + ) ) - assert cast(dict[str, Any], result["query"])["content"] == "test query" + assert response.query.content == "test query" mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["top_k"] == 4 @@ -183,11 +311,12 @@ class TestHitTestingService: @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") def test_retrieve_should_handle_metadata_filtering( - self, mock_get_meta, mock_retrieve, db_session_with_containers: Session - ): + self, mock_get_meta: MagicMock, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) account = MagicMock() account.id = str(uuid4()) + external_retrieval_model: dict[str, object] = {} retrieval_model = { "search_method": "semantic_search", @@ -197,14 +326,15 @@ class TestHitTestingService: "score_threshold_enabled": False, } mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") - mock_retrieve.return_value = [] + retrieved_documents: list[Document] = [] + mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( dataset=dataset, query="test query", account=account, retrieval_model=retrieval_model, - external_retrieval_model={}, + external_retrieval_model=external_retrieval_model, ) mock_get_meta.assert_called_once() @@ -214,10 +344,11 @@ class TestHitTestingService: @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") def test_retrieve_should_return_empty_if_metadata_filtering_fails( - self, mock_get_meta, mock_retrieve, db_session_with_containers: Session - ): + self, mock_get_meta: MagicMock, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) account = MagicMock() + external_retrieval_model: dict[str, object] = {} retrieval_model = { "search_method": "semantic_search", @@ -226,28 +357,31 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - mock_get_meta.return_value = ({}, "condition_string") + empty_document_ids: dict[str, list[str]] = {} + mock_get_meta.return_value = (empty_document_ids, "condition_string") - result = cast( - dict[str, Any], + response = _RetrieveResponse.model_validate( HitTestingService.retrieve( dataset=dataset, query="test query", account=account, retrieval_model=retrieval_model, - external_retrieval_model={}, - ), + external_retrieval_model=external_retrieval_model, + ) ) - assert result["records"] == [] + assert response.records == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + def test_retrieve_should_handle_attachments( + self, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) account = MagicMock() account.id = str(uuid4()) attachment_ids = ["att1", "att2"] + external_retrieval_model: dict[str, object] = {} retrieval_model = { "search_method": "semantic_search", @@ -255,19 +389,20 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - mock_retrieve.return_value = [] + retrieved_documents: list[Document] = [] + mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( dataset=dataset, query="test query", account=account, retrieval_model=retrieval_model, - external_retrieval_model={}, + external_retrieval_model=external_retrieval_model, attachment_ids=attachment_ids, ) mock_retrieve.assert_called_once_with( - retrieval_method=ANY, + retrieval_method=RetrievalMethod.SEMANTIC_SEARCH, dataset_id=dataset.id, query="test query", attachment_ids=attachment_ids, @@ -295,10 +430,13 @@ class TestHitTestingService: assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + def test_retrieve_should_handle_reranking_and_threshold( + self, mock_retrieve: MagicMock, db_session_with_containers: Session + ) -> None: dataset = _create_dataset(db_session_with_containers) account = MagicMock() account.id = str(uuid4()) + external_retrieval_model: dict[str, object] = {} retrieval_model = { "search_method": "hybrid_search", @@ -310,14 +448,15 @@ class TestHitTestingService: "score_threshold": 0.5, "weights": {"vector": 0.5, "keyword": 0.5}, } - mock_retrieve.return_value = [] + retrieved_documents: list[Document] = [] + mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( dataset=dataset, query="test query", account=account, retrieval_model=retrieval_model, - external_retrieval_model={}, + external_retrieval_model=external_retrieval_model, ) mock_retrieve.assert_called_once() @@ -326,3 +465,57 @@ class TestHitTestingService: assert kwargs["reranking_model"] == {"provider": "test"} assert kwargs["reranking_mode"] == "weighted_sum" assert kwargs["weights"] == {"vector": 0.5, "keyword": 0.5} + + def test_dump_dataset_document_returns_frontend_required_fields(self, db_session_with_containers: Session) -> None: + document = _create_dataset_document(db_session_with_containers, doc_metadata={"source": "manual"}) + + assert HitTestingService._dump_dataset_document(document) == { + "id": document.id, + "data_source_type": "upload_file", + "name": "guide.md", + "doc_type": None, + "doc_metadata": {"source": "manual"}, + } + + def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self) -> None: + segment = _build_segment(document_id="") + record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95}) + + records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record])) + + assert len(records) == 1 + assert records[0].segment.id == segment.id + assert records[0].segment.document_id == "" + assert records[0].score == 0.95 + + def test_dump_retrieval_records_injects_documents(self, db_session_with_containers: Session) -> None: + document = _create_dataset_document(db_session_with_containers) + segment = _create_segment(db_session_with_containers, document=document) + record = RetrievalSegments.model_validate({"segment": segment, "score": 0.9}) + + records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record])) + + assert len(records) == 1 + dumped_segment = records[0].segment + assert dumped_segment.id == segment.id + assert dumped_segment.document_id == document.id + assert dumped_segment.created_at == segment.created_at + assert dumped_segment.document == _DumpedDocumentResponse( + id=document.id, + data_source_type="upload_file", + name="guide.md", + doc_type=None, + doc_metadata=None, + ) + assert records[0].score == 0.9 + + def test_dump_retrieval_records_skips_records_with_missing_documents( + self, db_session_with_containers: Session, caplog: pytest.LogCaptureFixture + ) -> None: + segment = _create_segment(db_session_with_containers) + record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95}) + + result = HitTestingService._dump_retrieval_records([record]) + + assert result == [] + assert "Skipping hit-testing records with missing documents" in caplog.text diff --git a/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py b/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py deleted file mode 100644 index 5dd0194fd01..00000000000 --- a/api/tests/unit_tests/services/test_hit_testing_service_dump_records.py +++ /dev/null @@ -1,102 +0,0 @@ -from datetime import datetime -from unittest.mock import Mock, patch - -import pytest - -from services.hit_testing_service import HitTestingService - - -def _retrieval_record(payload: dict): - record = Mock() - record.model_dump.return_value = payload - segment = payload.get("segment") - if isinstance(segment, dict): - record.segment = Mock() - record.segment.id = segment.get("id") - record.segment.document_id = segment.get("document_id") - record.segment.created_at = datetime(2024, 1, 1, 0, 0, 0) - else: - record.segment = None - return record - - -def _dataset_document( - document_id: str = "document-1", - name: str = "guide.md", - data_source_type: str = "upload_file", - doc_type: str | None = None, - doc_metadata: dict | None = None, -): - document = Mock() - document.id = document_id - document.name = name - document.data_source_type = data_source_type - document.doc_type = doc_type - document.doc_metadata = doc_metadata - return document - - -class TestHitTestingServiceDumpRecords: - def test_dump_dataset_document_returns_frontend_required_fields(self): - document = _dataset_document(doc_metadata={"source": "manual"}) - - assert HitTestingService._dump_dataset_document(document) == { - "id": "document-1", - "data_source_type": "upload_file", - "name": "guide.md", - "doc_type": None, - "doc_metadata": {"source": "manual"}, - } - - def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self): - record = _retrieval_record({"segment": {"id": "segment-1", "document_id": None}, "score": 0.95}) - record.segment.document_id = None - - assert HitTestingService._dump_retrieval_records([record]) == [ - {"segment": {"id": "segment-1", "document_id": None}, "score": 0.95} - ] - - def test_dump_retrieval_records_injects_documents(self): - record_with_document = _retrieval_record( - { - "segment": { - "id": "segment-1", - "document_id": "document-1", - }, - "score": 0.9, - } - ) - scalars_result = Mock() - scalars_result.all.return_value = [_dataset_document()] - - with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result): - result = HitTestingService._dump_retrieval_records([record_with_document]) - - assert result[0]["segment"]["document"] == { - "id": "document-1", - "data_source_type": "upload_file", - "name": "guide.md", - "doc_type": None, - "doc_metadata": None, - } - - assert result[0]["segment"]["created_at"] == datetime(2024, 1, 1, 0, 0, 0) - - def test_dump_retrieval_records_skips_records_with_missing_documents(self, caplog: pytest.LogCaptureFixture): - record = _retrieval_record( - { - "segment": { - "id": "segment-1", - "document_id": "missing-document", - }, - "score": 0.95, - } - ) - scalars_result = Mock() - scalars_result.all.return_value = [] - - with patch("services.hit_testing_service.db.session.scalars", return_value=scalars_result): - result = HitTestingService._dump_retrieval_records([record]) - - assert result == [] - assert "Skipping hit-testing records with missing documents" in caplog.text