From 54bde0bdf61fb902330fe79133740ea00ff48f44 Mon Sep 17 00:00:00 2001 From: Prince Pal <107296821+princepal9120@users.noreply.github.com> Date: Fri, 1 May 2026 14:50:14 +0530 Subject: [PATCH] fix(api): preserve dataset metadata filters (#35700) --- .../knowledge_entities/knowledge_entities.py | 2 + .../service_api/dataset/test_hit_testing.py | 56 +++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index b1fe352861..910f54bebc 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -3,6 +3,7 @@ from typing import Any, Literal from pydantic import BaseModel, field_validator from core.rag.entities import Rule +from core.rag.entities.metadata_entities import MetadataFilteringCondition from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -83,6 +84,7 @@ class RetrievalModel(BaseModel): score_threshold_enabled: bool score_threshold: float | None = None weights: WeightModel | None = None + metadata_filtering_conditions: MetadataFilteringCondition | None = None class MetaDataConfig(BaseModel): diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index 9be8e56f56..a26cdf6563 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -171,6 +171,62 @@ class TestHitTestingApiPost: assert passed_retrieval_model["search_method"] == "semantic_search" assert passed_retrieval_model["top_k"] == 10 + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") + @patch("controllers.console.datasets.hit_testing_base.marshal") + @patch("controllers.console.datasets.hit_testing_base.HitTestingService") + @patch("controllers.console.datasets.hit_testing_base.DatasetService") + @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) + def test_post_preserves_retrieval_model_metadata_filtering_conditions( + self, + mock_current_user, + mock_dataset_svc, + mock_hit_svc, + mock_marshal, + mock_ns, + app, + ): + """Service API retrieval payload should not drop metadata filters.""" + dataset_id = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + mock_dataset = Mock() + mock_dataset.id = dataset_id + + mock_dataset_svc.get_dataset.return_value = mock_dataset + mock_dataset_svc.check_dataset_permission.return_value = None + mock_hit_svc.retrieve.return_value = {"query": "filtered query", "records": []} + mock_hit_svc.hit_testing_args_check.return_value = None + mock_marshal.return_value = [] + + metadata_filtering_conditions = { + "logical_operator": "and", + "conditions": [ + { + "name": "category", + "comparison_operator": "is", + "value": "finance", + } + ], + } + mock_ns.payload = { + "query": "filtered query", + "retrieval_model": { + "search_method": "semantic_search", + "reranking_enable": False, + "score_threshold_enabled": False, + "top_k": 4, + "metadata_filtering_conditions": metadata_filtering_conditions, + }, + } + + with app.test_request_context(): + api = HitTestingApi() + HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) + + passed_retrieval_model = mock_hit_svc.retrieve.call_args.kwargs.get("retrieval_model") + assert passed_retrieval_model is not None + assert passed_retrieval_model["metadata_filtering_conditions"] == metadata_filtering_conditions + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.marshal") @patch("controllers.console.datasets.hit_testing_base.HitTestingService")