mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 13:01:16 +08:00
refactor: pass session into hit testing service (#37785)
This commit is contained in:
parent
b3e5f29421
commit
7fc8eed716
@ -26,6 +26,7 @@ from controllers.console.wraps import (
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
@ -390,6 +391,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
|
||||
try:
|
||||
response = HitTestingService.external_retrieve(
|
||||
session=db.session,
|
||||
dataset=dataset,
|
||||
query=payload.query,
|
||||
account=current_user,
|
||||
|
||||
@ -18,6 +18,7 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import resolve_account_fallback
|
||||
from models.account import Account
|
||||
@ -115,6 +116,7 @@ class DatasetsHitTestingBase:
|
||||
try:
|
||||
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
|
||||
response = HitTestingService.retrieve(
|
||||
session=db.session,
|
||||
dataset=dataset,
|
||||
query=cast(str, args.get("query")),
|
||||
account=current_user,
|
||||
|
||||
@ -4,6 +4,7 @@ import time
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
|
||||
@ -12,7 +13,6 @@ from core.rag.index_processor.constant.query_type import QueryType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities import LLMMode
|
||||
from models import Account
|
||||
from models.dataset import Dataset, DatasetQuery
|
||||
@ -56,7 +56,9 @@ class HitTestingService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _dump_retrieval_records(cls, records: list[RetrievalSegments]) -> list[dict[str, Any]]:
|
||||
def _dump_retrieval_records(
|
||||
cls, session: Session | scoped_session, records: list[RetrievalSegments]
|
||||
) -> list[dict[str, Any]]:
|
||||
document_ids = {
|
||||
document_id
|
||||
for record in records
|
||||
@ -69,9 +71,7 @@ class HitTestingService:
|
||||
|
||||
documents = {
|
||||
document.id: cls._dump_dataset_document(document)
|
||||
for document in db.session.scalars(
|
||||
select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
|
||||
).all()
|
||||
for document in session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()
|
||||
}
|
||||
|
||||
records_with_documents: list[dict[str, Any]] = []
|
||||
@ -105,6 +105,7 @@ class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
session: Session | scoped_session,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
@ -142,7 +143,7 @@ class HitTestingService:
|
||||
if metadata_filter_document_ids:
|
||||
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return cls.compact_retrieve_response(query, [])
|
||||
return cls.compact_retrieve_response(session, query, [])
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrieval_method=RetrievalMethod(
|
||||
resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
|
||||
@ -181,14 +182,15 @@ class HitTestingService:
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
session.add(dataset_query)
|
||||
session.commit()
|
||||
|
||||
return cls.compact_retrieve_response(query, all_documents)
|
||||
return cls.compact_retrieve_response(session, query, all_documents)
|
||||
|
||||
@classmethod
|
||||
def external_retrieve(
|
||||
cls,
|
||||
session: Session | scoped_session,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
@ -222,20 +224,22 @@ class HitTestingService:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
session.add(dataset_query)
|
||||
session.commit()
|
||||
|
||||
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
|
||||
|
||||
@classmethod
|
||||
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict:
|
||||
def compact_retrieve_response(
|
||||
cls, session: Session | scoped_session, query: str, documents: list[Document]
|
||||
) -> RetrieveResponseDict:
|
||||
records = RetrievalService.format_retrieval_documents(documents)
|
||||
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
},
|
||||
"records": cls._dump_retrieval_records(records),
|
||||
"records": cls._dump_retrieval_records(session, records),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -181,7 +181,9 @@ class TestHitTestingService:
|
||||
# ── Response formatting ────────────────────────────────────────────
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
|
||||
def test_compact_retrieve_response_should_format_correctly(self, mock_format: MagicMock) -> None:
|
||||
def test_compact_retrieve_response_should_format_correctly(
|
||||
self, mock_format: MagicMock, db_session_with_containers: Session
|
||||
) -> None:
|
||||
query = "test query"
|
||||
mock_doc = MagicMock(spec=Document)
|
||||
|
||||
@ -189,7 +191,9 @@ class TestHitTestingService:
|
||||
mock_record.model_dump.return_value = {"content": "formatted content"}
|
||||
mock_format.return_value = [mock_record]
|
||||
|
||||
response = _RetrieveResponse.model_validate(HitTestingService.compact_retrieve_response(query, [mock_doc]))
|
||||
response = _RetrieveResponse.model_validate(
|
||||
HitTestingService.compact_retrieve_response(db_session_with_containers, query, [mock_doc])
|
||||
)
|
||||
|
||||
assert response.query.content == query
|
||||
assert len(response.records) == 1
|
||||
@ -242,6 +246,7 @@ class TestHitTestingService:
|
||||
|
||||
response = _RetrieveResponse.model_validate(
|
||||
HitTestingService.external_retrieve(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
query='test "query"',
|
||||
account=account,
|
||||
@ -269,7 +274,9 @@ class TestHitTestingService:
|
||||
dataset = _create_dataset(db_session_with_containers, provider="vendor")
|
||||
account = MagicMock()
|
||||
|
||||
response = _RetrieveResponse.model_validate(HitTestingService.external_retrieve(dataset, "test query", account))
|
||||
response = _RetrieveResponse.model_validate(
|
||||
HitTestingService.external_retrieve(db_session_with_containers, dataset, "test query", account)
|
||||
)
|
||||
|
||||
assert response.query.content == "test query"
|
||||
assert response.records == []
|
||||
@ -292,6 +299,7 @@ class TestHitTestingService:
|
||||
|
||||
response = _RetrieveResponse.model_validate(
|
||||
HitTestingService.retrieve(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
@ -320,7 +328,11 @@ class TestHitTestingService:
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"metadata_filtering_conditions": {"some": "condition"},
|
||||
"metadata_filtering_conditions": {
|
||||
"conditions": [
|
||||
{"name": "category", "comparison_operator": "is", "value": "test"},
|
||||
],
|
||||
},
|
||||
"top_k": 5,
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
@ -330,6 +342,7 @@ class TestHitTestingService:
|
||||
mock_retrieve.return_value = retrieved_documents
|
||||
|
||||
HitTestingService.retrieve(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
@ -352,7 +365,11 @@ class TestHitTestingService:
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"metadata_filtering_conditions": {"some": "condition"},
|
||||
"metadata_filtering_conditions": {
|
||||
"conditions": [
|
||||
{"name": "category", "comparison_operator": "is", "value": "test"},
|
||||
],
|
||||
},
|
||||
"top_k": 5,
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
@ -362,6 +379,7 @@ class TestHitTestingService:
|
||||
|
||||
response = _RetrieveResponse.model_validate(
|
||||
HitTestingService.retrieve(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
@ -393,6 +411,7 @@ class TestHitTestingService:
|
||||
mock_retrieve.return_value = retrieved_documents
|
||||
|
||||
HitTestingService.retrieve(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
@ -452,6 +471,7 @@ class TestHitTestingService:
|
||||
mock_retrieve.return_value = retrieved_documents
|
||||
|
||||
HitTestingService.retrieve(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
query="test query",
|
||||
account=account,
|
||||
@ -477,11 +497,15 @@ class TestHitTestingService:
|
||||
"doc_metadata": {"source": "manual"},
|
||||
}
|
||||
|
||||
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self) -> None:
|
||||
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
segment = _build_segment(document_id="")
|
||||
record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95})
|
||||
|
||||
records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record]))
|
||||
records = _DUMPED_RETRIEVAL_RECORDS.validate_python(
|
||||
HitTestingService._dump_retrieval_records(db_session_with_containers, [record])
|
||||
)
|
||||
|
||||
assert len(records) == 1
|
||||
assert records[0].segment.id == segment.id
|
||||
@ -493,7 +517,9 @@ class TestHitTestingService:
|
||||
segment = _create_segment(db_session_with_containers, document=document)
|
||||
record = RetrievalSegments.model_validate({"segment": segment, "score": 0.9})
|
||||
|
||||
records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record]))
|
||||
records = _DUMPED_RETRIEVAL_RECORDS.validate_python(
|
||||
HitTestingService._dump_retrieval_records(db_session_with_containers, [record])
|
||||
)
|
||||
|
||||
assert len(records) == 1
|
||||
dumped_segment = records[0].segment
|
||||
@ -515,7 +541,7 @@ class TestHitTestingService:
|
||||
segment = _create_segment(db_session_with_containers)
|
||||
record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95})
|
||||
|
||||
result = HitTestingService._dump_retrieval_records([record])
|
||||
result = HitTestingService._dump_retrieval_records(db_session_with_containers, [record])
|
||||
|
||||
assert result == []
|
||||
assert "Skipping hit-testing records with missing documents" in caplog.text
|
||||
|
||||
@ -147,8 +147,7 @@ class TestHitTestingServiceRetrieve:
|
||||
Provides a mocked database session for testing database operations
|
||||
like adding and committing DatasetQuery records.
|
||||
"""
|
||||
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
return MagicMock()
|
||||
|
||||
def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
|
||||
"""
|
||||
@ -186,7 +185,9 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_format.return_value = mock_records
|
||||
|
||||
# Act
|
||||
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
|
||||
result = HitTestingService.retrieve(
|
||||
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["query"]["content"] == query
|
||||
@ -232,7 +233,9 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_format.return_value = mock_records
|
||||
|
||||
# Act
|
||||
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
|
||||
result = HitTestingService.retrieve(
|
||||
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["query"]["content"] == query
|
||||
@ -257,9 +260,11 @@ class TestHitTestingServiceRetrieve:
|
||||
retrieval_model = {
|
||||
"metadata_filtering_conditions": {
|
||||
"conditions": [
|
||||
{"field": "category", "operator": "is", "value": "test"},
|
||||
{"name": "category", "comparison_operator": "is", "value": "test"},
|
||||
],
|
||||
},
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
external_retrieval_model = {}
|
||||
|
||||
@ -286,7 +291,9 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_format.return_value = mock_records
|
||||
|
||||
# Act
|
||||
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
|
||||
result = HitTestingService.retrieve(
|
||||
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["query"]["content"] == query
|
||||
@ -308,9 +315,11 @@ class TestHitTestingServiceRetrieve:
|
||||
retrieval_model = {
|
||||
"metadata_filtering_conditions": {
|
||||
"conditions": [
|
||||
{"field": "category", "operator": "is", "value": "test"},
|
||||
{"name": "category", "comparison_operator": "is", "value": "test"},
|
||||
],
|
||||
},
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
external_retrieval_model = {}
|
||||
|
||||
@ -327,7 +336,9 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_format.return_value = []
|
||||
|
||||
# Act
|
||||
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
|
||||
result = HitTestingService.retrieve(
|
||||
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["query"]["content"] == query
|
||||
@ -344,6 +355,8 @@ class TestHitTestingServiceRetrieve:
|
||||
dataset_retrieval_model = {
|
||||
"search_method": RetrievalMethod.HYBRID_SEARCH,
|
||||
"top_k": 3,
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=dataset_retrieval_model)
|
||||
account = HitTestingTestDataFactory.create_user_mock()
|
||||
@ -366,7 +379,9 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_format.return_value = mock_records
|
||||
|
||||
# Act
|
||||
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
|
||||
result = HitTestingService.retrieve(
|
||||
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["query"]["content"] == query
|
||||
@ -391,8 +406,7 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
Provides a mocked database session for testing database operations
|
||||
like adding and committing DatasetQuery records.
|
||||
"""
|
||||
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
return MagicMock()
|
||||
|
||||
def test_external_retrieve_success(self, mock_db_session):
|
||||
"""
|
||||
@ -424,7 +438,7 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
|
||||
# Act
|
||||
result = HitTestingService.external_retrieve(
|
||||
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
|
||||
mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions
|
||||
)
|
||||
|
||||
# Assert
|
||||
@ -455,7 +469,7 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
|
||||
# Act
|
||||
result = HitTestingService.external_retrieve(
|
||||
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
|
||||
mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions
|
||||
)
|
||||
|
||||
# Assert
|
||||
@ -490,7 +504,7 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
|
||||
# Act
|
||||
result = HitTestingService.external_retrieve(
|
||||
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
|
||||
mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions
|
||||
)
|
||||
|
||||
# Assert
|
||||
@ -524,7 +538,7 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
|
||||
# Act
|
||||
result = HitTestingService.external_retrieve(
|
||||
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
|
||||
mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions
|
||||
)
|
||||
|
||||
# Assert
|
||||
@ -565,7 +579,7 @@ class TestHitTestingServiceCompactRetrieveResponse:
|
||||
mock_format.return_value = mock_records
|
||||
|
||||
# Act
|
||||
result = HitTestingService.compact_retrieve_response(query, documents)
|
||||
result = HitTestingService.compact_retrieve_response(MagicMock(), query, documents)
|
||||
|
||||
# Assert
|
||||
assert result["query"]["content"] == query
|
||||
@ -591,7 +605,7 @@ class TestHitTestingServiceCompactRetrieveResponse:
|
||||
mock_format.return_value = []
|
||||
|
||||
# Act
|
||||
result = HitTestingService.compact_retrieve_response(query, documents)
|
||||
result = HitTestingService.compact_retrieve_response(MagicMock(), query, documents)
|
||||
|
||||
# Assert
|
||||
assert result["query"]["content"] == query
|
||||
@ -708,7 +722,7 @@ class TestHitTestingServiceHitTestingArgsCheck:
|
||||
args = {"query": ""}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
|
||||
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
def test_hit_testing_args_check_none_query(self):
|
||||
@ -721,7 +735,7 @@ class TestHitTestingServiceHitTestingArgsCheck:
|
||||
args = {"query": None}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
|
||||
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
def test_hit_testing_args_check_too_long_query(self):
|
||||
@ -734,7 +748,7 @@ class TestHitTestingServiceHitTestingArgsCheck:
|
||||
args = {"query": "a" * 251}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
|
||||
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
def test_hit_testing_args_check_exactly_250_characters(self):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user