chore(api): convert MessagesCleanPolicy from ABC to Protocol (#37171)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eric Cao 2026-06-08 17:55:52 +08:00 committed by GitHub
parent a15ecf6bec
commit 0239b81cca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 78 additions and 9 deletions

View File

@ -862,15 +862,20 @@ class RetrievalService:
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
query = query or attachment_id
if not query:
if query:
rerank_query = query
query_type = QueryType.TEXT_QUERY
elif attachment_id:
rerank_query = attachment_id
query_type = QueryType.IMAGE_QUERY
else:
return
all_documents_item = data_post_processor.invoke(
query=query,
query=rerank_query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY,
query_type=query_type,
)
if not data_post_processor.rerank_runner and score_threshold:
all_documents_item = self._filter_documents_by_vector_score_threshold(

View File

@ -1,9 +1,8 @@
import datetime
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import override
from typing import Protocol, override
from configs import dify_config
from enums.cloud_plan import CloudPlan
@ -19,14 +18,13 @@ class SimpleMessage:
created_at: datetime.datetime
class MessagesCleanPolicy(ABC):
class MessagesCleanPolicy(Protocol):
"""
Abstract base class for message cleanup policies.
Protocol for message cleanup policies.
A policy determines which messages from a batch should be deleted.
"""
@abstractmethod
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],

View File

@ -980,6 +980,72 @@ class TestRetrievalService:
# Weights might be in positional args (position 3)
assert len(call_args.args) >= 4
@pytest.mark.parametrize("empty_query", ["", None])
@patch("core.rag.datasource.retrieval_service.DataPostProcessor")
@patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
@patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
def test_hybrid_search_attachment_only_uses_image_query_type(
    self,
    mock_get_dataset,
    mock_embedding_search,
    mock_data_processor_class,
    mock_dataset,
    sample_documents,
    empty_query,
):
    """
    Regression test for GH #37116: hybrid retrieval driven only by an attachment.

    With a falsy text query ("" or None) and a single attachment id, the
    post-processor's ``invoke`` is expected to be called exactly once, with
    ``query_type=QueryType.IMAGE_QUERY`` and ``query`` equal to the attachment id.
    """
    from core.rag.index_processor.constant.query_type import QueryType

    # Arrange: dataset lookup, embedding search and post-processor are all mocked.
    attachment_id = "upload-file-uuid-1234"
    seeded_documents = sample_documents[:2]
    mock_get_dataset.return_value = mock_dataset

    def fake_embedding_search(
        flask_app,
        dataset_id,
        query,
        top_k,
        score_threshold,
        reranking_model,
        all_documents,
        retrieval_method,
        exceptions,
        document_ids_filter=None,
        query_type=QueryType.TEXT_QUERY,
    ):
        # Stand in for the real search by filling the shared result list.
        all_documents.extend(seeded_documents)

    mock_embedding_search.side_effect = fake_embedding_search

    post_processor = Mock()
    post_processor.invoke.return_value = list(seeded_documents)
    mock_data_processor_class.return_value = post_processor

    # Act: hybrid retrieval with an attachment id and no usable text query.
    RetrievalService.retrieve(
        retrieval_method=RetrievalMethod.HYBRID_SEARCH,
        dataset_id=mock_dataset.id,
        query=empty_query,
        top_k=3,
        score_threshold=0.5,
        attachment_ids=[attachment_id],
    )

    # Assert: a single rerank call, typed as an image query on the attachment id.
    post_processor.invoke.assert_called_once()
    rerank_kwargs = post_processor.invoke.call_args.kwargs
    assert rerank_kwargs["query_type"] == QueryType.IMAGE_QUERY, (
        "Attachment-only hybrid search must use IMAGE_QUERY for reranking, not TEXT_QUERY"
    )
    assert rerank_kwargs["query"] == attachment_id, (
        "The rerank query must be the attachment_id, not the empty text query"
    )
# ==================== Full-Text Search Tests ====================
@patch("core.rag.datasource.retrieval_service.RetrievalService.full_text_index_search")