mirror of
https://github.com/langgenius/dify.git
synced 2026-06-16 14:01:10 +08:00
fix(api): preserve dataset metadata filters (#35700)
This commit is contained in:
parent
87add9a4f3
commit
54bde0bdf6
@ -3,6 +3,7 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.rag.entities import Rule
|
||||
from core.rag.entities.metadata_entities import MetadataFilteringCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
|
||||
@ -83,6 +84,7 @@ class RetrievalModel(BaseModel):
|
||||
score_threshold_enabled: bool
|
||||
score_threshold: float | None = None
|
||||
weights: WeightModel | None = None
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = None
|
||||
|
||||
|
||||
class MetaDataConfig(BaseModel):
|
||||
|
||||
@ -171,6 +171,62 @@ class TestHitTestingApiPost:
|
||||
assert passed_retrieval_model["search_method"] == "semantic_search"
|
||||
assert passed_retrieval_model["top_k"] == 10
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.marshal")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.DatasetService")
|
||||
@patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account))
|
||||
def test_post_preserves_retrieval_model_metadata_filtering_conditions(
|
||||
self,
|
||||
mock_current_user,
|
||||
mock_dataset_svc,
|
||||
mock_hit_svc,
|
||||
mock_marshal,
|
||||
mock_ns,
|
||||
app,
|
||||
):
|
||||
"""Service API retrieval payload should not drop metadata filters."""
|
||||
dataset_id = str(uuid.uuid4())
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.id = dataset_id
|
||||
|
||||
mock_dataset_svc.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_permission.return_value = None
|
||||
mock_hit_svc.retrieve.return_value = {"query": "filtered query", "records": []}
|
||||
mock_hit_svc.hit_testing_args_check.return_value = None
|
||||
mock_marshal.return_value = []
|
||||
|
||||
metadata_filtering_conditions = {
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "category",
|
||||
"comparison_operator": "is",
|
||||
"value": "finance",
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_ns.payload = {
|
||||
"query": "filtered query",
|
||||
"retrieval_model": {
|
||||
"search_method": "semantic_search",
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
"top_k": 4,
|
||||
"metadata_filtering_conditions": metadata_filtering_conditions,
|
||||
},
|
||||
}
|
||||
|
||||
with app.test_request_context():
|
||||
api = HitTestingApi()
|
||||
HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id)
|
||||
|
||||
passed_retrieval_model = mock_hit_svc.retrieve.call_args.kwargs.get("retrieval_model")
|
||||
assert passed_retrieval_model is not None
|
||||
assert passed_retrieval_model["metadata_filtering_conditions"] == metadata_filtering_conditions
|
||||
|
||||
@patch("controllers.service_api.dataset.hit_testing.service_api_ns")
|
||||
@patch("controllers.console.datasets.hit_testing_base.marshal")
|
||||
@patch("controllers.console.datasets.hit_testing_base.HitTestingService")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user