fix(api): preserve dataset metadata filters (#35700)

This commit is contained in:
Prince Pal 2026-05-01 14:50:14 +05:30 committed by GitHub
parent 87add9a4f3
commit 54bde0bdf6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 0 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, Literal
from pydantic import BaseModel, field_validator
from core.rag.entities import Rule
from core.rag.entities.metadata_entities import MetadataFilteringCondition
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -83,6 +84,7 @@ class RetrievalModel(BaseModel):
score_threshold_enabled: bool
score_threshold: float | None = None
weights: WeightModel | None = None
metadata_filtering_conditions: MetadataFilteringCondition | None = None
class MetaDataConfig(BaseModel):

View File

@@ -171,6 +171,62 @@ class TestHitTestingApiPost:
assert passed_retrieval_model["search_method"] == "semantic_search"
assert passed_retrieval_model["top_k"] == 10
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.marshal")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
def test_post_preserves_retrieval_model_metadata_filtering_conditions(
    self,
    mock_current_user,
    mock_dataset_svc,
    mock_hit_svc,
    mock_marshal,
    mock_ns,
    app,
):
    """Metadata filters sent in ``retrieval_model`` must reach ``retrieve`` unchanged.

    The namespace payload carries a ``metadata_filtering_conditions`` entry
    inside ``retrieval_model``; after ``HitTestingApi.post.__wrapped__`` runs,
    the ``retrieval_model`` keyword handed to the patched
    ``HitTestingService.retrieve`` must still contain an equal value.
    """
    dataset_id = str(uuid.uuid4())
    tenant_id = str(uuid.uuid4())

    # Dataset lookup and permission check both succeed against a stub dataset.
    dataset_stub = Mock(id=dataset_id)
    mock_dataset_svc.get_dataset.return_value = dataset_stub
    mock_dataset_svc.check_dataset_permission.return_value = None

    # Hit-testing service accepts the arguments and returns an empty result.
    mock_hit_svc.hit_testing_args_check.return_value = None
    mock_hit_svc.retrieve.return_value = {"query": "filtered query", "records": []}
    mock_marshal.return_value = []

    # The same dict object is placed in the payload and compared afterwards.
    expected_filters = {
        "logical_operator": "and",
        "conditions": [
            {
                "name": "category",
                "comparison_operator": "is",
                "value": "finance",
            }
        ],
    }
    retrieval_model_payload = {
        "search_method": "semantic_search",
        "reranking_enable": False,
        "score_threshold_enabled": False,
        "top_k": 4,
        "metadata_filtering_conditions": expected_filters,
    }
    mock_ns.payload = {
        "query": "filtered query",
        "retrieval_model": retrieval_model_payload,
    }

    with app.test_request_context():
        api = HitTestingApi()
        HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)

    forwarded_model = mock_hit_svc.retrieve.call_args.kwargs.get("retrieval_model")
    assert forwarded_model is not None
    assert forwarded_model["metadata_filtering_conditions"] == expected_filters
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
@patch("controllers.console.datasets.hit_testing_base.marshal")
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")