From 02fa84d853e1d797bb4ffadb300355488d7e70fc Mon Sep 17 00:00:00 2001 From: gaurav0107 Date: Thu, 7 May 2026 10:31:37 +0530 Subject: [PATCH] refactor: use MetadataFilteringCondition BaseModel in RetrievalService.external_retrieve Follow the dict -> Pydantic BaseModel pattern established by #31514, #34080 and #34422: RetrievalService.external_retrieve now accepts metadata_filtering_conditions: MetadataFilteringCondition | None directly instead of dict[str, Any] | None. The redundant internal MetadataFilteringCondition.model_validate(...) call is removed; validation now happens at the call site in HitTestingService.external_retrieve, following the "validate at the boundary" principle cited in the issue. Unit tests are updated to pass typed instances where appropriate and a regression test is added for the None path. Behavior is otherwise preserved. external_retrieval_model is left as dict[str, Any] because it is an opaque payload forwarded to external retrieval providers. Fixes part of #31497. --- api/core/rag/datasource/retrieval_service.py | 9 ++--- api/services/hit_testing_service.py | 10 +++++- .../datasource/test_datasource_retrieval.py | 34 +++++++++++++++---- api/tests/unit_tests/services/hit_service.py | 16 ++++++--- 4 files changed, 51 insertions(+), 18 deletions(-) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index b985ebbe1d..16fb7ba146 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -175,23 +175,18 @@ class RetrievalService: dataset_id: str, query: str, external_retrieval_model: dict[str, Any] | None = None, - metadata_filtering_conditions: dict[str, Any] | None = None, + metadata_filtering_conditions: MetadataFilteringCondition | None = None, ): stmt = select(Dataset).where(Dataset.id == dataset_id) dataset = db.session.scalar(stmt) if not dataset: return [] - metadata_condition = ( - MetadataFilteringCondition.model_validate(metadata_filtering_conditions) - if metadata_filtering_conditions - else None - ) all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( dataset.tenant_id, dataset_id, query, external_retrieval_model or {}, - metadata_condition=metadata_condition, + metadata_condition=metadata_filtering_conditions, ) return all_documents diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 2e5987dd28..14b95994f5 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -142,11 +142,19 @@ class HitTestingService: start = time.perf_counter() + from core.rag.entities import MetadataFilteringCondition + + validated_metadata_conditions = ( + MetadataFilteringCondition.model_validate(metadata_filtering_conditions) + if metadata_filtering_conditions + else None + ) + all_documents = RetrievalService.external_retrieve( dataset_id=dataset.id, query=cls.escape_query_for_search(query), external_retrieval_model=external_retrieval_model, - metadata_filtering_conditions=metadata_filtering_conditions, + metadata_filtering_conditions=validated_metadata_conditions, ) end = time.perf_counter() diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index b0ecad4d0c..b6cefb4daa 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -226,11 +226,35 @@ class TestRetrievalServiceInternals: assert mock_retrieve.call_count == 2 @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") - @patch("core.rag.datasource.retrieval_service.MetadataFilteringCondition.model_validate") @patch("core.rag.datasource.retrieval_service.db.session.scalar") - def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch): + def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_fetch): + from core.rag.entities import MetadataFilteringCondition + + mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") + expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] + mock_fetch.return_value = expected_documents + metadata_condition = MetadataFilteringCondition(logical_operator="and", conditions=[]) + + results = RetrievalService.external_retrieve( + dataset_id="dataset-1", + query="test query", + external_retrieval_model={"top_k": 3}, + metadata_filtering_conditions=metadata_condition, + ) + + assert results == expected_documents + mock_fetch.assert_called_once_with( + "tenant-1", + "dataset-1", + "test query", + {"top_k": 3}, + metadata_condition=metadata_condition, + ) + + @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") + @patch("core.rag.datasource.retrieval_service.db.session.scalar") + def test_external_retrieve_without_metadata_conditions(self, mock_scalar, mock_fetch): mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") - mock_validate.return_value = "validated-condition" expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")] mock_fetch.return_value = expected_documents @@ -238,17 +262,15 @@ class TestRetrievalServiceInternals: dataset_id="dataset-1", query="test query", external_retrieval_model={"top_k": 3}, - metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"}, ) assert results == expected_documents - mock_validate.assert_called_once() mock_fetch.assert_called_once_with( "tenant-1", "dataset-1", "test query", {"top_k": 3}, - metadata_condition="validated-condition", + metadata_condition=None, ) @patch("core.rag.datasource.retrieval_service.db.session.scalar") diff --git a/api/tests/unit_tests/services/hit_service.py b/api/tests/unit_tests/services/hit_service.py index ddbc7dc041..a4b6fd874c 100644 --- a/api/tests/unit_tests/services/hit_service.py +++ b/api/tests/unit_tests/services/hit_service.py @@ -467,15 +467,21 @@ class TestHitTestingServiceExternalRetrieve: """ Test external retrieval with metadata filtering conditions. - Verifies that metadata filtering conditions are properly passed - to the external retrieval service. + Verifies that metadata filtering conditions are validated into a + MetadataFilteringCondition instance at the service boundary before + being forwarded to the external retrieval service. """ + from core.rag.entities import MetadataFilteringCondition + # Arrange dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external") account = HitTestingTestDataFactory.create_user_mock() query = "test query" external_retrieval_model = {"top_k": 3} - metadata_filtering_conditions = {"category": "test"} + metadata_filtering_conditions = { + "logical_operator": "and", + "conditions": [{"name": "category", "comparison_operator": "contains", "value": "test"}], + } external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}] @@ -497,7 +503,9 @@ class TestHitTestingServiceExternalRetrieve: assert result["query"]["content"] == query assert len(result["records"]) == 1 call_kwargs = mock_external_retrieve.call_args[1] - assert call_kwargs["metadata_filtering_conditions"] == metadata_filtering_conditions + forwarded_condition = call_kwargs["metadata_filtering_conditions"] + assert isinstance(forwarded_condition, MetadataFilteringCondition) + assert forwarded_condition == MetadataFilteringCondition.model_validate(metadata_filtering_conditions) def test_external_retrieve_empty_documents(self, mock_db_session): """