mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
refactor: use MetadataFilteringCondition BaseModel in RetrievalService.external_retrieve
Follow the dict -> Pydantic BaseModel pattern established by #31514, #34080 and #34422: RetrievalService.external_retrieve now accepts metadata_filtering_conditions: MetadataFilteringCondition | None directly instead of dict[str, Any] | None.

The redundant internal MetadataFilteringCondition.model_validate(...) call is removed; validation now happens at the call site in HitTestingService.external_retrieve, following the "validate at the boundary" principle cited in the issue.

Unit tests are updated to pass typed instances where appropriate and a regression test is added for the None path. Behavior is otherwise preserved.

external_retrieval_model is left as dict[str, Any] because it is an opaque payload forwarded to external retrieval providers.

Fixes part of #31497.
This commit is contained in:
parent
8fd616d27f
commit
02fa84d853
@ -175,23 +175,18 @@ class RetrievalService:
|
|||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
query: str,
|
query: str,
|
||||||
external_retrieval_model: dict[str, Any] | None = None,
|
external_retrieval_model: dict[str, Any] | None = None,
|
||||||
metadata_filtering_conditions: dict[str, Any] | None = None,
|
metadata_filtering_conditions: MetadataFilteringCondition | None = None,
|
||||||
):
|
):
|
||||||
stmt = select(Dataset).where(Dataset.id == dataset_id)
|
stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||||
dataset = db.session.scalar(stmt)
|
dataset = db.session.scalar(stmt)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
return []
|
return []
|
||||||
metadata_condition = (
|
|
||||||
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
|
||||||
if metadata_filtering_conditions
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
dataset.tenant_id,
|
dataset.tenant_id,
|
||||||
dataset_id,
|
dataset_id,
|
||||||
query,
|
query,
|
||||||
external_retrieval_model or {},
|
external_retrieval_model or {},
|
||||||
metadata_condition=metadata_condition,
|
metadata_condition=metadata_filtering_conditions,
|
||||||
)
|
)
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
|
|||||||
@ -142,11 +142,19 @@ class HitTestingService:
|
|||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
from core.rag.entities import MetadataFilteringCondition
|
||||||
|
|
||||||
|
validated_metadata_conditions = (
|
||||||
|
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||||
|
if metadata_filtering_conditions
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
all_documents = RetrievalService.external_retrieve(
|
all_documents = RetrievalService.external_retrieve(
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=cls.escape_query_for_search(query),
|
query=cls.escape_query_for_search(query),
|
||||||
external_retrieval_model=external_retrieval_model,
|
external_retrieval_model=external_retrieval_model,
|
||||||
metadata_filtering_conditions=metadata_filtering_conditions,
|
metadata_filtering_conditions=validated_metadata_conditions,
|
||||||
)
|
)
|
||||||
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
|||||||
@ -226,11 +226,35 @@ class TestRetrievalServiceInternals:
|
|||||||
assert mock_retrieve.call_count == 2
|
assert mock_retrieve.call_count == 2
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval")
|
@patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval")
|
||||||
@patch("core.rag.datasource.retrieval_service.MetadataFilteringCondition.model_validate")
|
|
||||||
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
||||||
def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch):
|
def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_fetch):
|
||||||
|
from core.rag.entities import MetadataFilteringCondition
|
||||||
|
|
||||||
|
mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1")
|
||||||
|
expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")]
|
||||||
|
mock_fetch.return_value = expected_documents
|
||||||
|
metadata_condition = MetadataFilteringCondition(logical_operator="and", conditions=[])
|
||||||
|
|
||||||
|
results = RetrievalService.external_retrieve(
|
||||||
|
dataset_id="dataset-1",
|
||||||
|
query="test query",
|
||||||
|
external_retrieval_model={"top_k": 3},
|
||||||
|
metadata_filtering_conditions=metadata_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert results == expected_documents
|
||||||
|
mock_fetch.assert_called_once_with(
|
||||||
|
"tenant-1",
|
||||||
|
"dataset-1",
|
||||||
|
"test query",
|
||||||
|
{"top_k": 3},
|
||||||
|
metadata_condition=metadata_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval")
|
||||||
|
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
||||||
|
def test_external_retrieve_without_metadata_conditions(self, mock_scalar, mock_fetch):
|
||||||
mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1")
|
mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1")
|
||||||
mock_validate.return_value = "validated-condition"
|
|
||||||
expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")]
|
expected_documents = [create_mock_document("external-doc", "external-1", 0.8, provider="external")]
|
||||||
mock_fetch.return_value = expected_documents
|
mock_fetch.return_value = expected_documents
|
||||||
|
|
||||||
@ -238,17 +262,15 @@ class TestRetrievalServiceInternals:
|
|||||||
dataset_id="dataset-1",
|
dataset_id="dataset-1",
|
||||||
query="test query",
|
query="test query",
|
||||||
external_retrieval_model={"top_k": 3},
|
external_retrieval_model={"top_k": 3},
|
||||||
metadata_filtering_conditions={"field": "source", "operator": "contains", "value": "manual"},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert results == expected_documents
|
assert results == expected_documents
|
||||||
mock_validate.assert_called_once()
|
|
||||||
mock_fetch.assert_called_once_with(
|
mock_fetch.assert_called_once_with(
|
||||||
"tenant-1",
|
"tenant-1",
|
||||||
"dataset-1",
|
"dataset-1",
|
||||||
"test query",
|
"test query",
|
||||||
{"top_k": 3},
|
{"top_k": 3},
|
||||||
metadata_condition="validated-condition",
|
metadata_condition=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
@patch("core.rag.datasource.retrieval_service.db.session.scalar")
|
||||||
|
|||||||
@ -467,15 +467,21 @@ class TestHitTestingServiceExternalRetrieve:
|
|||||||
"""
|
"""
|
||||||
Test external retrieval with metadata filtering conditions.
|
Test external retrieval with metadata filtering conditions.
|
||||||
|
|
||||||
Verifies that metadata filtering conditions are properly passed
|
Verifies that metadata filtering conditions are validated into a
|
||||||
to the external retrieval service.
|
MetadataFilteringCondition instance at the service boundary before
|
||||||
|
being forwarded to the external retrieval service.
|
||||||
"""
|
"""
|
||||||
|
from core.rag.entities import MetadataFilteringCondition
|
||||||
|
|
||||||
# Arrange
|
# Arrange
|
||||||
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
|
dataset = HitTestingTestDataFactory.create_dataset_mock(provider="external")
|
||||||
account = HitTestingTestDataFactory.create_user_mock()
|
account = HitTestingTestDataFactory.create_user_mock()
|
||||||
query = "test query"
|
query = "test query"
|
||||||
external_retrieval_model = {"top_k": 3}
|
external_retrieval_model = {"top_k": 3}
|
||||||
metadata_filtering_conditions = {"category": "test"}
|
metadata_filtering_conditions = {
|
||||||
|
"logical_operator": "and",
|
||||||
|
"conditions": [{"name": "category", "comparison_operator": "contains", "value": "test"}],
|
||||||
|
}
|
||||||
|
|
||||||
external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
|
external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
|
||||||
|
|
||||||
@ -497,7 +503,9 @@ class TestHitTestingServiceExternalRetrieve:
|
|||||||
assert result["query"]["content"] == query
|
assert result["query"]["content"] == query
|
||||||
assert len(result["records"]) == 1
|
assert len(result["records"]) == 1
|
||||||
call_kwargs = mock_external_retrieve.call_args[1]
|
call_kwargs = mock_external_retrieve.call_args[1]
|
||||||
assert call_kwargs["metadata_filtering_conditions"] == metadata_filtering_conditions
|
forwarded_condition = call_kwargs["metadata_filtering_conditions"]
|
||||||
|
assert isinstance(forwarded_condition, MetadataFilteringCondition)
|
||||||
|
assert forwarded_condition == MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
|
||||||
|
|
||||||
def test_external_retrieve_empty_documents(self, mock_db_session):
|
def test_external_retrieve_empty_documents(self, mock_db_session):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user