mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
test: migrate hit_testing_service tests to testcontainers (#34750)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
28b8215c9b
commit
e224c77920
@ -1,239 +1,193 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class TestHitTestingService:
|
||||
"""Test suite for HitTestingService"""
|
||||
def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
|
||||
tenant_id = str(uuid4())
|
||||
created_by = str(uuid4())
|
||||
ds = Dataset(
|
||||
tenant_id=kwargs.get("tenant_id", tenant_id),
|
||||
name=kwargs.get("name", "test-dataset"),
|
||||
created_by=kwargs.get("created_by", created_by),
|
||||
provider=provider,
|
||||
)
|
||||
db_session.add(ds)
|
||||
db_session.commit()
|
||||
db_session.refresh(ds)
|
||||
return ds
|
||||
|
||||
# ===== Utility Method Tests =====
|
||||
|
||||
class TestHitTestingService:
|
||||
# ── Utility methods (pure logic, no DB) ────────────────────────────
|
||||
|
||||
def test_escape_query_for_search_should_escape_double_quotes(self):
|
||||
"""Test that escape_query_for_search escapes double quotes correctly"""
|
||||
# Arrange
|
||||
query = 'test "query" with quotes'
|
||||
expected = 'test \\"query\\" with quotes'
|
||||
|
||||
# Act
|
||||
result = HitTestingService.escape_query_for_search(query)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
assert result == 'test \\"query\\" with quotes'
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_query(self):
|
||||
"""Test that hit_testing_args_check passes with a valid query"""
|
||||
# Arrange
|
||||
args = {"query": "valid query"}
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
HitTestingService.hit_testing_args_check({"query": "valid query"})
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
|
||||
"""Test that hit_testing_args_check passes with valid attachment_ids"""
|
||||
# Arrange
|
||||
args = {"attachment_ids": ["id1", "id2"]}
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
|
||||
# Arrange
|
||||
args = {}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Query or attachment_ids is required" in str(exc_info.value)
|
||||
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
|
||||
HitTestingService.hit_testing_args_check({})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
|
||||
# Arrange
|
||||
args = {"query": "a" * 251}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Query cannot exceed 250 characters" in str(exc_info.value)
|
||||
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
|
||||
HitTestingService.hit_testing_args_check({"query": "a" * 251})
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
|
||||
# Arrange
|
||||
args = {"attachment_ids": "not a list"}
|
||||
with pytest.raises(ValueError, match="Attachment_ids must be a list"):
|
||||
HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"})
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Attachment_ids must be a list" in str(exc_info.value)
|
||||
|
||||
# ===== Response Formatting Tests =====
|
||||
# ── Response formatting ────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
|
||||
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
|
||||
"""Test that compact_retrieve_response formats the response correctly"""
|
||||
# Arrange
|
||||
query = "test query"
|
||||
mock_doc = MagicMock(spec=Document)
|
||||
documents = [mock_doc]
|
||||
|
||||
mock_record = MagicMock()
|
||||
mock_record.model_dump.return_value = {"content": "formatted content"}
|
||||
mock_format.return_value = [mock_record]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
|
||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc]))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert len(result["records"]) == 1
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
|
||||
mock_format.assert_called_once_with(documents)
|
||||
mock_format.assert_called_once_with([mock_doc])
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
|
||||
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "external"
|
||||
query = "test query"
|
||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||
documents = [
|
||||
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
|
||||
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
|
||||
]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
||||
result = cast(
|
||||
dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert len(result["records"]) == 2
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
|
||||
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
|
||||
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "not_external"
|
||||
query = "test query"
|
||||
documents = [{"content": "c1"}]
|
||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== External Retrieve Tests =====
|
||||
# ── External retrieve (real DB) ────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
|
||||
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
dataset.provider = "external"
|
||||
query = 'test "query"'
|
||||
def test_external_retrieve_should_succeed_for_external_provider(
|
||||
self, mock_ext_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="external")
|
||||
account_id = str(uuid4())
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
account.id = account_id
|
||||
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
|
||||
|
||||
# Act
|
||||
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query='test "query"',
|
||||
account=account,
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == 'test "query"'
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
|
||||
|
||||
# Verify call to RetrievalService.external_retrieve with escaped query
|
||||
mock_ext_retrieve.assert_called_once_with(
|
||||
dataset_id="dataset_id",
|
||||
dataset_id=dataset.id,
|
||||
query='test \\"query\\"',
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
)
|
||||
|
||||
# Verify DatasetQuery record was added and committed
|
||||
mock_add.assert_called_once()
|
||||
mock_commit.assert_called_once()
|
||||
db_session_with_containers.expire_all()
|
||||
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
assert after_count == before_count + 1
|
||||
|
||||
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
|
||||
"""Test that external_retrieve returns empty results immediately if provider is not external"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "not_external"
|
||||
query = "test query"
|
||||
def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||
account = MagicMock()
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
|
||||
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== Retrieve Tests =====
|
||||
# ── Retrieve (real DB) ─────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve uses default model when retrieval_model is not provided"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
def test_retrieve_should_use_default_model_when_none_provided(
|
||||
self, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
dataset.retrieval_model = None
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
account.id = str(uuid4())
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
|
||||
dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={}
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["query"])["content"] == "test query"
|
||||
mock_retrieve.assert_called_once()
|
||||
# Verify top_k from default_retrieval_model (4)
|
||||
assert mock_retrieve.call_args.kwargs["top_k"] == 4
|
||||
mock_commit.assert_called_once()
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0
|
||||
assert after_count == before_count + 1
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
|
||||
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_metadata_filtering(
|
||||
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
@ -242,29 +196,27 @@ class TestHitTestingService:
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
# Mock metadata filtering response
|
||||
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
|
||||
mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string")
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_get_meta.assert_called_once()
|
||||
mock_retrieve.assert_called_once()
|
||||
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
|
||||
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(
|
||||
self, mock_get_meta, mock_retrieve, db_session_with_containers: Session
|
||||
):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
|
||||
retrieval_model = {
|
||||
@ -274,37 +226,27 @@ class TestHitTestingService:
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
# Mock metadata filtering response: condition returned but no IDs
|
||||
mock_get_meta.return_value = ({}, "condition_string")
|
||||
|
||||
# Act
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["records"] == []
|
||||
mock_retrieve.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
attachment_ids = ["att1", "att2"]
|
||||
|
||||
retrieval_model = {
|
||||
@ -315,21 +257,19 @@ class TestHitTestingService:
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_retrieve.assert_called_once_with(
|
||||
retrieval_method=ANY,
|
||||
dataset_id="dataset_id",
|
||||
query=query,
|
||||
dataset_id=dataset.id,
|
||||
query="test query",
|
||||
attachment_ids=attachment_ids,
|
||||
top_k=4,
|
||||
score_threshold=0.0,
|
||||
@ -338,26 +278,27 @@ class TestHitTestingService:
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
)
|
||||
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
|
||||
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
|
||||
called_query = mock_add.call_args[0][0]
|
||||
query_content = json.loads(called_query.content)
|
||||
|
||||
# Verify DatasetQuery was persisted with correct content structure
|
||||
db_session_with_containers.expire_all()
|
||||
latest = db_session_with_containers.scalar(
|
||||
select(DatasetQuery)
|
||||
.where(DatasetQuery.dataset_id == dataset.id)
|
||||
.order_by(DatasetQuery.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
assert latest is not None
|
||||
query_content = json.loads(latest.content)
|
||||
assert len(query_content) == 3 # 1 text + 2 images
|
||||
assert query_content[0]["content_type"] == "text_query"
|
||||
assert query_content[1]["content_type"] == "image_query"
|
||||
assert query_content[1]["content"] == "att1"
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve passes reranking and threshold parameters correctly"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session):
|
||||
dataset = _create_dataset(db_session_with_containers)
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
account.id = str(uuid4())
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "hybrid_search",
|
||||
@ -371,12 +312,14 @@ class TestHitTestingService:
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_retrieve.assert_called_once()
|
||||
kwargs = mock_retrieve.call_args.kwargs
|
||||
assert kwargs["score_threshold"] == 0.5
|
||||
Loading…
Reference in New Issue
Block a user