This commit is contained in:
wangxiaolei 2025-12-29 09:20:47 +08:00 committed by GitHub
commit bc66d05403
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 283 additions and 1 deletions

View File

@ -515,6 +515,7 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
dataset_count = len(available_datasets)
with measure_time() as timer:
if query:
query_thread = threading.Thread(
@ -534,6 +535,7 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
"dataset_count": dataset_count,
},
)
all_threads.append(query_thread)
@ -557,6 +559,7 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
"dataset_count": dataset_count,
},
)
all_threads.append(attachment_thread)
@ -1404,6 +1407,7 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
dataset_count: int,
):
with flask_app.app_context():
threads = []
@ -1439,7 +1443,8 @@ class DatasetRetrieval:
for thread in threads:
thread.join()
if reranking_enable:
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:

View File

@ -73,6 +73,7 @@ import pytest
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from models.dataset import Dataset
@ -1507,6 +1508,282 @@ class TestRetrievalService:
call_kwargs = mock_retrieve.call_args.kwargs
assert call_kwargs["reranking_model"] == reranking_model
# ==================== Multiple Retrieve Thread Tests ====================
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
def test_multiple_retrieve_thread_skips_second_reranking_with_single_dataset(
    self, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
):
    """
    With ``dataset_count=1`` the second reranking pass must not run.

    ``_multiple_retrieve_thread`` is invoked with reranking enabled but a
    single dataset. ``_retriever`` is replaced by a stub that feeds two
    documents into the shared result list.

    Verifies:
    - DataPostProcessor is never instantiated
    - Both retrieved documents end up in ``all_documents`` in retrieval order
    """

    def build_doc(content, doc_id, score):
        # Every document gets its own random document_id but shares the dataset id.
        return Document(
            page_content=content,
            metadata={
                "doc_id": doc_id,
                "score": score,
                "document_id": str(uuid4()),
                "dataset_id": mock_dataset.id,
            },
            provider="dify",
        )

    # Arrange
    retrieval = DatasetRetrieval()
    tenant = str(uuid4())
    retrieved_docs = [
        build_doc("Test content 1", "doc1", 0.9),
        build_doc("Test content 2", "doc2", 0.8),
    ]

    # Stub for _retriever: parameter names are kept as-is so the stub accepts
    # the same keyword arguments as the method it replaces.
    def fake_retriever(
        flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
    ):
        all_documents.extend(retrieved_docs)

    mock_retriever.side_effect = fake_retriever

    # The dataset under test uses high_quality indexing
    mock_dataset.indexing_technique = "high_quality"

    collected = []
    thread_kwargs = {
        "flask_app": mock_flask_app,
        "available_datasets": [mock_dataset],
        "metadata_condition": None,
        "metadata_filter_document_ids": None,
        "all_documents": collected,
        "tenant_id": tenant,
        "reranking_enable": True,
        "reranking_mode": "reranking_model",
        "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
        "weights": None,
        "top_k": 5,
        "score_threshold": 0.5,
        "query": "test query",
        "attachment_id": None,
        # Single dataset - the second reranking should be skipped
        "dataset_count": 1,
    }

    # Act
    retrieval._multiple_retrieve_thread(**thread_kwargs)

    # Assert: no post-processor was created, so no second reranking happened
    mock_data_processor_class.assert_not_called()

    # Assert: both documents were still collected, in their original order
    assert [doc.page_content for doc in collected] == ["Test content 1", "Test content 2"]
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
def test_multiple_retrieve_thread_performs_second_reranking_with_multiple_datasets(
    self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
):
    """
    With ``dataset_count=2`` the second reranking pass must run.

    Two datasets are passed in with reranking enabled. ``_retriever`` is
    stubbed to return two documents per call, and the patched
    DataPostProcessor returns those documents in reversed order.

    Verifies:
    - DataPostProcessor is constructed exactly once with the tenant id,
      reranking mode, reranking model, weights and ``False``
    - ``invoke`` is called exactly once on the processor instance
    - ``all_documents`` holds the processor's output in the reranked order
    """

    def build_doc(content, doc_id, score):
        # Every document gets its own random document_id but shares the dataset id.
        return Document(
            page_content=content,
            metadata={
                "doc_id": doc_id,
                "score": score,
                "document_id": str(uuid4()),
                "dataset_id": mock_dataset.id,
            },
            provider="dify",
        )

    # Arrange
    retrieval = DatasetRetrieval()
    tenant = str(uuid4())
    retrieved_docs = [
        build_doc("Test content 1", "doc1", 0.7),
        build_doc("Test content 2", "doc2", 0.6),
    ]

    # Stub for _retriever: parameter names are kept as-is so the stub accepts
    # the same keyword arguments as the method it replaces.
    def fake_retriever(
        flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
    ):
        all_documents.extend(retrieved_docs)

    mock_retriever.side_effect = fake_retriever

    # First dataset uses high_quality indexing
    mock_dataset.indexing_technique = "high_quality"

    # Simulated reranking output: reversed order with updated scores
    reranked = [
        build_doc("Test content 2", "doc2", 0.95),
        build_doc("Test content 1", "doc1", 0.85),
    ]
    processor = Mock()
    processor.configure_mock(**{"invoke.return_value": reranked})
    mock_data_processor_class.return_value = processor

    collected = []

    # Second dataset, configured at construction time
    second_dataset = Mock(
        spec=Dataset,
        id=str(uuid4()),
        indexing_technique="high_quality",
        provider="dify",
    )

    thread_kwargs = {
        "flask_app": mock_flask_app,
        "available_datasets": [mock_dataset, second_dataset],
        "metadata_condition": None,
        "metadata_filter_document_ids": None,
        "all_documents": collected,
        "tenant_id": tenant,
        "reranking_enable": True,
        "reranking_mode": "reranking_model",
        "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
        "weights": None,
        "top_k": 5,
        "score_threshold": 0.5,
        "query": "test query",
        "attachment_id": None,
        # Multiple datasets - the second reranking should be performed
        "dataset_count": 2,
    }

    # Act
    retrieval._multiple_retrieve_thread(**thread_kwargs)

    # Assert: the post-processor was built once with the reranking configuration
    mock_data_processor_class.assert_called_once_with(
        tenant,
        "reranking_model",
        {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
        None,
        False,
    )

    # Assert: reranking was invoked exactly once
    processor.invoke.assert_called_once()

    # Assert: the collected documents follow the reranked order
    assert [doc.page_content for doc in collected] == ["Test content 2", "Test content 1"]
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
def test_multiple_retrieve_thread_single_dataset_uses_standard_scoring(
    self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
):
    """
    With ``dataset_count=1`` and reranking enabled, vector scoring is used.

    The dataset is marked ``high_quality`` and ``calculate_vector_score`` is
    patched to return a single scored document.

    Verifies:
    - DataPostProcessor is never instantiated
    - ``calculate_vector_score`` is called exactly once, with ``top_k`` (5)
      as its second positional argument
    - ``all_documents`` contains only what ``calculate_vector_score`` returned
    """

    def build_doc(content, doc_id, score):
        # Every document gets its own random document_id but shares the dataset id.
        return Document(
            page_content=content,
            metadata={
                "doc_id": doc_id,
                "score": score,
                "document_id": str(uuid4()),
                "dataset_id": mock_dataset.id,
            },
            provider="dify",
        )

    # Arrange
    retrieval = DatasetRetrieval()
    tenant = str(uuid4())
    retrieved_docs = [
        build_doc("Test content 1", "doc1", 0.9),
        build_doc("Test content 2", "doc2", 0.8),
    ]

    # Stub for _retriever: parameter names are kept as-is so the stub accepts
    # the same keyword arguments as the method it replaces.
    def fake_retriever(
        flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
    ):
        all_documents.extend(retrieved_docs)

    mock_retriever.side_effect = fake_retriever

    # The dataset under test uses high_quality indexing
    mock_dataset.indexing_technique = "high_quality"

    # The patched scorer hands back a single scored document
    mock_calculate_vector_score.return_value = [build_doc("Test content 1", "doc1", 0.95)]

    collected = []
    thread_kwargs = {
        "flask_app": mock_flask_app,
        "available_datasets": [mock_dataset],
        "metadata_condition": None,
        "metadata_filter_document_ids": None,
        "all_documents": collected,
        "tenant_id": tenant,
        # Reranking is enabled but should be skipped for a single dataset
        "reranking_enable": True,
        "reranking_mode": "reranking_model",
        "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
        "weights": None,
        "top_k": 5,
        "score_threshold": 0.5,
        "query": "test query",
        "attachment_id": None,
        "dataset_count": 1,
    }

    # Act
    retrieval._multiple_retrieve_thread(**thread_kwargs)

    # Assert: no post-processor was created
    mock_data_processor_class.assert_not_called()

    # Assert: vector scoring ran once and received top_k as its second positional argument
    mock_calculate_vector_score.assert_called_once()
    positional_args = mock_calculate_vector_score.call_args.args
    assert positional_args[1] == 5  # top_k

    # Assert: only the scorer's output was collected
    assert [doc.page_content for doc in collected] == ["Test content 1"]
class TestRetrievalMethods:
"""