mirror of https://github.com/langgenius/dify.git
Merge 42208edd08 into a00ac1b5b1
This commit is contained in:
commit
bc66d05403
|
|
@ -515,6 +515,7 @@ class DatasetRetrieval:
|
|||
0
|
||||
].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
dataset_count = len(available_datasets)
|
||||
with measure_time() as timer:
|
||||
if query:
|
||||
query_thread = threading.Thread(
|
||||
|
|
@ -534,6 +535,7 @@ class DatasetRetrieval:
|
|||
"score_threshold": score_threshold,
|
||||
"query": query,
|
||||
"attachment_id": None,
|
||||
"dataset_count": dataset_count,
|
||||
},
|
||||
)
|
||||
all_threads.append(query_thread)
|
||||
|
|
@ -557,6 +559,7 @@ class DatasetRetrieval:
|
|||
"score_threshold": score_threshold,
|
||||
"query": None,
|
||||
"attachment_id": attachment_id,
|
||||
"dataset_count": dataset_count,
|
||||
},
|
||||
)
|
||||
all_threads.append(attachment_thread)
|
||||
|
|
@ -1404,6 +1407,7 @@ class DatasetRetrieval:
|
|||
score_threshold: float,
|
||||
query: str | None,
|
||||
attachment_id: str | None,
|
||||
dataset_count: int,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
threads = []
|
||||
|
|
@ -1439,7 +1443,8 @@ class DatasetRetrieval:
|
|||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if reranking_enable:
|
||||
# Skip second reranking when there is only one dataset
|
||||
if reranking_enable and dataset_count > 1:
|
||||
# do rerank for searched documents
|
||||
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
|
||||
if query:
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ import pytest
|
|||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
|
@ -1507,6 +1508,282 @@ class TestRetrievalService:
|
|||
call_kwargs = mock_retrieve.call_args.kwargs
|
||||
assert call_kwargs["reranking_model"] == reranking_model
|
||||
|
||||
# ==================== Multiple Retrieve Thread Tests ====================
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||
def test_multiple_retrieve_thread_skips_second_reranking_with_single_dataset(
|
||||
self, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||
):
|
||||
"""
|
||||
Test that _multiple_retrieve_thread skips second reranking when dataset_count is 1.
|
||||
|
||||
When there is only one dataset, the second reranking is unnecessary
|
||||
because the documents are already ranked from the first retrieval.
|
||||
This optimization avoids the overhead of reranking when it won't
|
||||
provide any benefit.
|
||||
|
||||
Verifies:
|
||||
- DataPostProcessor is NOT called when dataset_count == 1
|
||||
- Documents are still added to all_documents
|
||||
- Standard scoring logic is applied instead
|
||||
"""
|
||||
# Arrange
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Create test documents
|
||||
doc1 = Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
doc2 = Document(
|
||||
page_content="Test content 2",
|
||||
metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
# Mock _retriever to return documents
|
||||
def side_effect_retriever(
|
||||
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||
):
|
||||
all_documents.extend([doc1, doc2])
|
||||
|
||||
mock_retriever.side_effect = side_effect_retriever
|
||||
|
||||
# Set up dataset with high_quality indexing
|
||||
mock_dataset.indexing_technique = "high_quality"
|
||||
|
||||
all_documents = []
|
||||
|
||||
# Act - Call with dataset_count = 1
|
||||
dataset_retrieval._multiple_retrieve_thread(
|
||||
flask_app=mock_flask_app,
|
||||
available_datasets=[mock_dataset],
|
||||
metadata_condition=None,
|
||||
metadata_filter_document_ids=None,
|
||||
all_documents=all_documents,
|
||||
tenant_id=tenant_id,
|
||||
reranking_enable=True,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
weights=None,
|
||||
top_k=5,
|
||||
score_threshold=0.5,
|
||||
query="test query",
|
||||
attachment_id=None,
|
||||
dataset_count=1, # Single dataset - should skip second reranking
|
||||
)
|
||||
|
||||
# Assert
|
||||
# DataPostProcessor should NOT be called (second reranking skipped)
|
||||
mock_data_processor_class.assert_not_called()
|
||||
|
||||
# Documents should still be added to all_documents
|
||||
assert len(all_documents) == 2
|
||||
assert all_documents[0].page_content == "Test content 1"
|
||||
assert all_documents[1].page_content == "Test content 2"
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
|
||||
def test_multiple_retrieve_thread_performs_second_reranking_with_multiple_datasets(
|
||||
self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||
):
|
||||
"""
|
||||
Test that _multiple_retrieve_thread performs second reranking when dataset_count > 1.
|
||||
|
||||
When there are multiple datasets, the second reranking is necessary
|
||||
to merge and re-rank results from different datasets. This ensures
|
||||
the most relevant documents across all datasets are returned.
|
||||
|
||||
Verifies:
|
||||
- DataPostProcessor IS called when dataset_count > 1
|
||||
- Reranking is applied with correct parameters
|
||||
- Documents are processed correctly
|
||||
"""
|
||||
# Arrange
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Create test documents
|
||||
doc1 = Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.7, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
doc2 = Document(
|
||||
page_content="Test content 2",
|
||||
metadata={"doc_id": "doc2", "score": 0.6, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
# Mock _retriever to return documents
|
||||
def side_effect_retriever(
|
||||
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||
):
|
||||
all_documents.extend([doc1, doc2])
|
||||
|
||||
mock_retriever.side_effect = side_effect_retriever
|
||||
|
||||
# Set up dataset with high_quality indexing
|
||||
mock_dataset.indexing_technique = "high_quality"
|
||||
|
||||
# Mock DataPostProcessor instance and its invoke method
|
||||
mock_processor_instance = Mock()
|
||||
# Simulate reranking - return documents in reversed order with updated scores
|
||||
reranked_docs = [
|
||||
Document(
|
||||
page_content="Test content 2",
|
||||
metadata={"doc_id": "doc2", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
),
|
||||
Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.85, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
),
|
||||
]
|
||||
mock_processor_instance.invoke.return_value = reranked_docs
|
||||
mock_data_processor_class.return_value = mock_processor_instance
|
||||
|
||||
all_documents = []
|
||||
|
||||
# Create second dataset
|
||||
mock_dataset2 = Mock(spec=Dataset)
|
||||
mock_dataset2.id = str(uuid4())
|
||||
mock_dataset2.indexing_technique = "high_quality"
|
||||
mock_dataset2.provider = "dify"
|
||||
|
||||
# Act - Call with dataset_count = 2
|
||||
dataset_retrieval._multiple_retrieve_thread(
|
||||
flask_app=mock_flask_app,
|
||||
available_datasets=[mock_dataset, mock_dataset2],
|
||||
metadata_condition=None,
|
||||
metadata_filter_document_ids=None,
|
||||
all_documents=all_documents,
|
||||
tenant_id=tenant_id,
|
||||
reranking_enable=True,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
weights=None,
|
||||
top_k=5,
|
||||
score_threshold=0.5,
|
||||
query="test query",
|
||||
attachment_id=None,
|
||||
dataset_count=2, # Multiple datasets - should perform second reranking
|
||||
)
|
||||
|
||||
# Assert
|
||||
# DataPostProcessor SHOULD be called (second reranking performed)
|
||||
mock_data_processor_class.assert_called_once_with(
|
||||
tenant_id,
|
||||
"reranking_model",
|
||||
{"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
None,
|
||||
False,
|
||||
)
|
||||
|
||||
# Verify invoke was called with correct parameters
|
||||
mock_processor_instance.invoke.assert_called_once()
|
||||
|
||||
# Documents should be added to all_documents after reranking
|
||||
assert len(all_documents) == 2
|
||||
# The reranked order should be reflected
|
||||
assert all_documents[0].page_content == "Test content 2"
|
||||
assert all_documents[1].page_content == "Test content 1"
|
||||
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval._retriever")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.calculate_vector_score")
|
||||
def test_multiple_retrieve_thread_single_dataset_uses_standard_scoring(
|
||||
self, mock_calculate_vector_score, mock_retriever, mock_data_processor_class, mock_flask_app, mock_dataset
|
||||
):
|
||||
"""
|
||||
Test that _multiple_retrieve_thread uses standard scoring when dataset_count is 1
|
||||
and reranking is enabled.
|
||||
|
||||
When there's only one dataset, instead of using DataPostProcessor,
|
||||
the method should fall through to the standard scoring logic
|
||||
(calculate_vector_score for high_quality datasets).
|
||||
|
||||
Verifies:
|
||||
- DataPostProcessor is NOT called
|
||||
- calculate_vector_score IS called for high_quality indexing
|
||||
- Documents are scored correctly
|
||||
"""
|
||||
# Arrange
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
tenant_id = str(uuid4())
|
||||
|
||||
# Create test documents
|
||||
doc1 = Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.9, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
doc2 = Document(
|
||||
page_content="Test content 2",
|
||||
metadata={"doc_id": "doc2", "score": 0.8, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
)
|
||||
|
||||
# Mock _retriever to return documents
|
||||
def side_effect_retriever(
|
||||
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
|
||||
):
|
||||
all_documents.extend([doc1, doc2])
|
||||
|
||||
mock_retriever.side_effect = side_effect_retriever
|
||||
|
||||
# Set up dataset with high_quality indexing
|
||||
mock_dataset.indexing_technique = "high_quality"
|
||||
|
||||
# Mock calculate_vector_score to return scored documents
|
||||
scored_docs = [
|
||||
Document(
|
||||
page_content="Test content 1",
|
||||
metadata={"doc_id": "doc1", "score": 0.95, "document_id": str(uuid4()), "dataset_id": mock_dataset.id},
|
||||
provider="dify",
|
||||
),
|
||||
]
|
||||
mock_calculate_vector_score.return_value = scored_docs
|
||||
|
||||
all_documents = []
|
||||
|
||||
# Act - Call with dataset_count = 1
|
||||
dataset_retrieval._multiple_retrieve_thread(
|
||||
flask_app=mock_flask_app,
|
||||
available_datasets=[mock_dataset],
|
||||
metadata_condition=None,
|
||||
metadata_filter_document_ids=None,
|
||||
all_documents=all_documents,
|
||||
tenant_id=tenant_id,
|
||||
reranking_enable=True, # Reranking enabled but should be skipped for single dataset
|
||||
reranking_mode="reranking_model",
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
|
||||
weights=None,
|
||||
top_k=5,
|
||||
score_threshold=0.5,
|
||||
query="test query",
|
||||
attachment_id=None,
|
||||
dataset_count=1,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# DataPostProcessor should NOT be called
|
||||
mock_data_processor_class.assert_not_called()
|
||||
|
||||
# calculate_vector_score SHOULD be called for high_quality datasets
|
||||
mock_calculate_vector_score.assert_called_once()
|
||||
call_args = mock_calculate_vector_score.call_args
|
||||
assert call_args[0][1] == 5 # top_k
|
||||
|
||||
# Documents should be added after standard scoring
|
||||
assert len(all_documents) == 1
|
||||
assert all_documents[0].page_content == "Test content 1"
|
||||
|
||||
|
||||
class TestRetrievalMethods:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue