diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index b985ebbe1d..16fb7ba146 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -175,23 +175,18 @@ class RetrievalService: dataset_id: str, query: str, external_retrieval_model: dict[str, Any] | None = None, - metadata_filtering_conditions: dict[str, Any] | None = None, + metadata_filtering_conditions: MetadataFilteringCondition | None = None, ): stmt = select(Dataset).where(Dataset.id == dataset_id) dataset = db.session.scalar(stmt) if not dataset: return [] - metadata_condition = ( - MetadataFilteringCondition.model_validate(metadata_filtering_conditions) - if metadata_filtering_conditions - else None - ) all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( dataset.tenant_id, dataset_id, query, external_retrieval_model or {}, - metadata_condition=metadata_condition, + metadata_condition=metadata_filtering_conditions, ) return all_documents diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 2e5987dd28..14b95994f5 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -142,11 +142,19 @@ class HitTestingService: start = time.perf_counter() + from core.rag.entities import MetadataFilteringCondition + + validated_metadata_conditions = ( + MetadataFilteringCondition.model_validate(metadata_filtering_conditions) + if metadata_filtering_conditions + else None + ) + all_documents = RetrievalService.external_retrieve( dataset_id=dataset.id, query=cls.escape_query_for_search(query), external_retrieval_model=external_retrieval_model, - metadata_filtering_conditions=metadata_filtering_conditions, + metadata_filtering_conditions=validated_metadata_conditions, ) end = time.perf_counter() diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index b0ecad4d0c..b6cefb4daa 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -226,11 +226,35 @@ class TestRetrievalServiceInternals: assert mock_retrieve.call_count == 2 @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") - @patch("core.rag.datasource.retrieval_service.MetadataFilteringCondition.model_validate") @patch("core.rag.datasource.retrieval_service.db.session.scalar") - def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch): + def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_fetch): + from core.rag.entities import MetadataFilteringCondition + + mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") + expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] + mock_fetch.return_value = expected_documents + metadata_condition = MetadataFilteringCondition(logical_operator="and", conditions=[]) + + results = RetrievalService.external_retrieve( + dataset_id="dataset-1", + query="test query", + external_retrieval_model={"top_k": 3}, + metadata_filtering_conditions=metadata_condition, + ) + + assert results == expected_documents + mock_fetch.assert_called_once_with( + "tenant-1", + "dataset-1", + "test query", + {"top_k": 3}, + metadata_condition=metadata_condition, + ) + + @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_without_metadata_conditions(self, mock_scalar, mock_fetch): mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") - mock_validate.return_value = "validated-condition" expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] mock_fetch.return_value = expected_documents @@ -238,17 +262,15 @@ class TestRetrievalServiceInternals: dataset_id="dataset-1", query="test query", external_retrieval_model={"top_k": 3}, - metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"}, ) assert results == expected_documents - mock_validate.assert_called_once() mock_fetch.assert_called_once_with( "tenant-1", "dataset-1", "test query", {"top_k": 3}, - metadata_condition="validated-condition", + metadata_condition=None, ) @patch("core.rag.datasource.retrieval_service.db.session.scalar") diff --git a/api/tests/unit_tests/services/hit_service.py b/api/tests/unit_tests/services/hit_service.py index ddbc7dc041..a4b6fd874c 100644 --- a/api/tests/unit_tests/services/hit_service.py +++ b/api/tests/unit_tests/services/hit_service.py @@ -467,15 +467,21 @@ class TestHitTestingServiceExternalRetrieve: """ Test external retrieval with metadata filtering conditions. - Verifies that metadata filtering conditions are properly passed - to the external retrieval service. + Verifies that metadata filtering conditions are validated into a + MetadataFilteringCondition instance at the service boundary before + being forwarded to the external retrieval service. """ + from core.rag.entities import MetadataFilteringCondition + # Arrange dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external") account = HitTestingTestDataFactory.create_user_mock() query = "test query" external_retrieval_model = {"top_k": 3} - metadata_filtering_conditions = {"category": "test"} + metadata_filtering_conditions = { + "logical_operator": "and", + "conditions": [{"name": "category", "comparison_operator": "contains", "value": "test"}], + } external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}] @@ -497,7 +503,9 @@ class TestHitTestingServiceExternalRetrieve: assert result["query"]["content"] == query assert len(result["records"]) == 1 call_kwargs = mock_external_retrieve.call_args[1] - assert call_kwargs["metadata_filtering_conditions"] == metadata_filtering_conditions + forwarded_condition = call_kwargs["metadata_filtering_conditions"] + assert isinstance(forwarded_condition, MetadataFilteringCondition) + assert forwarded_condition == MetadataFilteringCondition.model_validate(metadata_filtering_conditions) def test_external_retrieve_empty_documents(self, mock_db_session): """