mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 21:28:25 +08:00
refactor: use MetadataFilteringCondition BaseModel in RetrievalService.external_retrieve
Follow the dict -> Pydantic BaseModel pattern established by #31514, #34080 and #34422: RetrievalService.external_retrieve now accepts metadata_filtering_conditions: MetadataFilteringCondition | None directly instead of dict[str, Any] | None.

The redundant internal MetadataFilteringCondition.model_validate(...) call is removed; validation now happens at the call site in HitTestingService.external_retrieve, following the "validate at the boundary" principle cited in the issue.

Unit tests are updated to pass typed instances where appropriate, and a regression test is added for the None path. Behavior is otherwise preserved.

external_retrieval_model is left as dict[str, Any] because it is an opaque payload forwarded to external retrieval providers.

Fixes part of #31497.
This commit is contained in:
parent
8fd616d27f
commit
02fa84d853
@ -175,23 +175,18 @@ class RetrievalService:
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_model: dict[str, Any] | None = None,
|
||||
metadata_filtering_conditions: dict[str, Any] | None = None,
|
||||
metadata_filtering_conditions: MetadataFilteringCondition | None = None,
|
||||
):
|
||||
stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||
if metadata_filtering_conditions
|
||||
else None
|
||||
)
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id,
|
||||
dataset_id,
|
||||
query,
|
||||
external_retrieval_model or {},
|
||||
metadata_condition=metadata_condition,
|
||||
metadata_condition=metadata_filtering_conditions,
|
||||
)
|
||||
return all_documents
|
||||
|
||||
|
||||
@ -142,11 +142,19 @@ class HitTestingService:
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
|
||||
validated_metadata_conditions = (
|
||||
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||
if metadata_filtering_conditions
|
||||
else None
|
||||
)
|
||||
|
||||
all_documents = RetrievalService.external_retrieve(
|
||||
dataset_id=dataset.id,
|
||||
query=cls.escape_query_for_search(query),
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
metadata_filtering_conditions=metadata_filtering_conditions,
|
||||
metadata_filtering_conditions=validated_metadata_conditions,
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
||||
@ -226,11 +226,35 @@ class TestRetrievalServiceInternals:
|
||||
assert mock_retrieve.call_count == 2
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval")
|
||||
@patch("core.rag.datasource.retrieval_service.MetadataFilteringCondition.model_validate")
|
||||
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
||||
def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch):
|
||||
def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_fetch):
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
|
||||
mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1")
|
||||
expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")]
|
||||
mock_fetch.return_value = expected_documents
|
||||
metadata_condition = MetadataFilteringCondition(logical_operator="and", conditions=[])
|
||||
|
||||
results = RetrievalService.external_retrieve(
|
||||
dataset_id="dataset-1",
|
||||
query="test query",
|
||||
external_retrieval_model={"top_k": 3},
|
||||
metadata_filtering_conditions=metadata_condition,
|
||||
)
|
||||
|
||||
assert results == expected_documents
|
||||
mock_fetch.assert_called_once_with(
|
||||
"tenant-1",
|
||||
"dataset-1",
|
||||
"test query",
|
||||
{"top_k": 3},
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval")
|
||||
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
||||
def test_external_retrieve_without_metadata_conditions(self, mock_scalar, mock_fetch):
|
||||
mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1")
|
||||
mock_validate.return_value = "validated-condition"
|
||||
expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")]
|
||||
mock_fetch.return_value = expected_documents
|
||||
|
||||
@ -238,17 +262,15 @@ class TestRetrievalServiceInternals:
|
||||
dataset_id="dataset-1",
|
||||
query="test query",
|
||||
external_retrieval_model={"top_k": 3},
|
||||
metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"},
|
||||
)
|
||||
|
||||
assert results == expected_documents
|
||||
mock_validate.assert_called_once()
|
||||
mock_fetch.assert_called_once_with(
|
||||
"tenant-1",
|
||||
"dataset-1",
|
||||
"test query",
|
||||
{"top_k": 3},
|
||||
metadata_condition="validated-condition",
|
||||
metadata_condition=None,
|
||||
)
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
||||
|
||||
@ -467,15 +467,21 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
"""
|
||||
Test external retrieval with metadata filtering conditions.
|
||||
|
||||
Verifies that metadata filtering conditions are properly passed
|
||||
to the external retrieval service.
|
||||
Verifies that metadata filtering conditions are validated into a
|
||||
MetadataFilteringCondition instance at the service boundary before
|
||||
being forwarded to the external retrieval service.
|
||||
"""
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
|
||||
# Arrange
|
||||
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
|
||||
account = HitTestingTestDataFactory.create_user_mock()
|
||||
query = "test query"
|
||||
external_retrieval_model = {"top_k": 3}
|
||||
metadata_filtering_conditions = {"category": "test"}
|
||||
metadata_filtering_conditions = {
|
||||
"logical_operator": "and",
|
||||
"conditions": [{"name": "category", "comparison_operator": "contains", "value": "test"}],
|
||||
}
|
||||
|
||||
external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
|
||||
|
||||
@ -497,7 +503,9 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
assert result["query"]["content"] == query
|
||||
assert len(result["records"]) == 1
|
||||
call_kwargs = mock_external_retrieve.call_args[1]
|
||||
assert call_kwargs["metadata_filtering_conditions"] == metadata_filtering_conditions
|
||||
forwarded_condition = call_kwargs["metadata_filtering_conditions"]
|
||||
assert isinstance(forwarded_condition, MetadataFilteringCondition)
|
||||
assert forwarded_condition == MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||
|
||||
def test_external_retrieve_empty_documents(self, mock_db_session):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user