refactor: pass session into hit testing service (#37785)

This commit is contained in:
Myshkin451 2026-06-23 14:21:38 +08:00 committed by GitHub
parent b3e5f29421
commit 7fc8eed716
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 90 additions and 42 deletions

View File

@ -26,6 +26,7 @@ from controllers.console.wraps import (
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.dataset_fields import (
dataset_detail_fields,
@ -390,6 +391,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
try:
response = HitTestingService.external_retrieve(
session=db.session,
dataset=dataset,
query=payload.query,
account=current_user,

View File

@ -18,6 +18,7 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import resolve_account_fallback
from models.account import Account
@ -115,6 +116,7 @@ class DatasetsHitTestingBase:
try:
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
response = HitTestingService.retrieve(
session=db.session,
dataset=dataset,
query=cast(str, args.get("query")),
account=current_user,

View File

@ -4,6 +4,7 @@ import time
from typing import Any, TypedDict, cast
from sqlalchemy import select
from sqlalchemy.orm import Session, scoped_session
from core.app.app_config.entities import ModelConfig
from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
@ -12,7 +13,6 @@ from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from graphon.model_runtime.entities import LLMMode
from models import Account
from models.dataset import Dataset, DatasetQuery
@ -56,7 +56,9 @@ class HitTestingService:
}
@classmethod
def _dump_retrieval_records(cls, records: list[RetrievalSegments]) -> list[dict[str, Any]]:
def _dump_retrieval_records(
cls, session: Session | scoped_session, records: list[RetrievalSegments]
) -> list[dict[str, Any]]:
document_ids = {
document_id
for record in records
@ -69,9 +71,7 @@ class HitTestingService:
documents = {
document.id: cls._dump_dataset_document(document)
for document in db.session.scalars(
select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))
).all()
for document in session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()
}
records_with_documents: list[dict[str, Any]] = []
@ -105,6 +105,7 @@ class HitTestingService:
@classmethod
def retrieve(
cls,
session: Session | scoped_session,
dataset: Dataset,
query: str,
account: Account,
@ -142,7 +143,7 @@ class HitTestingService:
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
return cls.compact_retrieve_response(session, query, [])
all_documents = RetrievalService.retrieve(
retrieval_method=RetrievalMethod(
resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
@ -181,14 +182,15 @@ class HitTestingService:
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
db.session.add(dataset_query)
db.session.commit()
session.add(dataset_query)
session.commit()
return cls.compact_retrieve_response(query, all_documents)
return cls.compact_retrieve_response(session, query, all_documents)
@classmethod
def external_retrieve(
cls,
session: Session | scoped_session,
dataset: Dataset,
query: str,
account: Account,
@ -222,20 +224,22 @@ class HitTestingService:
created_by=account.id,
)
db.session.add(dataset_query)
db.session.commit()
session.add(dataset_query)
session.commit()
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod
def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict:
def compact_retrieve_response(
cls, session: Session | scoped_session, query: str, documents: list[Document]
) -> RetrieveResponseDict:
records = RetrievalService.format_retrieval_documents(documents)
return {
"query": {
"content": query,
},
"records": cls._dump_retrieval_records(records),
"records": cls._dump_retrieval_records(session, records),
}
@classmethod

View File

@ -181,7 +181,9 @@ class TestHitTestingService:
# ── Response formatting ────────────────────────────────────────────
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
def test_compact_retrieve_response_should_format_correctly(self, mock_format: MagicMock) -> None:
def test_compact_retrieve_response_should_format_correctly(
self, mock_format: MagicMock, db_session_with_containers: Session
) -> None:
query = "test query"
mock_doc = MagicMock(spec=Document)
@ -189,7 +191,9 @@ class TestHitTestingService:
mock_record.model_dump.return_value = {"content": "formatted content"}
mock_format.return_value = [mock_record]
response = _RetrieveResponse.model_validate(HitTestingService.compact_retrieve_response(query, [mock_doc]))
response = _RetrieveResponse.model_validate(
HitTestingService.compact_retrieve_response(db_session_with_containers, query, [mock_doc])
)
assert response.query.content == query
assert len(response.records) == 1
@ -242,6 +246,7 @@ class TestHitTestingService:
response = _RetrieveResponse.model_validate(
HitTestingService.external_retrieve(
db_session_with_containers,
dataset=dataset,
query='test "query"',
account=account,
@ -269,7 +274,9 @@ class TestHitTestingService:
dataset = _create_dataset(db_session_with_containers, provider="vendor")
account = MagicMock()
response = _RetrieveResponse.model_validate(HitTestingService.external_retrieve(dataset, "test query", account))
response = _RetrieveResponse.model_validate(
HitTestingService.external_retrieve(db_session_with_containers, dataset, "test query", account)
)
assert response.query.content == "test query"
assert response.records == []
@ -292,6 +299,7 @@ class TestHitTestingService:
response = _RetrieveResponse.model_validate(
HitTestingService.retrieve(
db_session_with_containers,
dataset=dataset,
query="test query",
account=account,
@ -320,7 +328,11 @@ class TestHitTestingService:
retrieval_model = {
"search_method": "semantic_search",
"metadata_filtering_conditions": {"some": "condition"},
"metadata_filtering_conditions": {
"conditions": [
{"name": "category", "comparison_operator": "is", "value": "test"},
],
},
"top_k": 5,
"reranking_enable": False,
"score_threshold_enabled": False,
@ -330,6 +342,7 @@ class TestHitTestingService:
mock_retrieve.return_value = retrieved_documents
HitTestingService.retrieve(
db_session_with_containers,
dataset=dataset,
query="test query",
account=account,
@ -352,7 +365,11 @@ class TestHitTestingService:
retrieval_model = {
"search_method": "semantic_search",
"metadata_filtering_conditions": {"some": "condition"},
"metadata_filtering_conditions": {
"conditions": [
{"name": "category", "comparison_operator": "is", "value": "test"},
],
},
"top_k": 5,
"reranking_enable": False,
"score_threshold_enabled": False,
@ -362,6 +379,7 @@ class TestHitTestingService:
response = _RetrieveResponse.model_validate(
HitTestingService.retrieve(
db_session_with_containers,
dataset=dataset,
query="test query",
account=account,
@ -393,6 +411,7 @@ class TestHitTestingService:
mock_retrieve.return_value = retrieved_documents
HitTestingService.retrieve(
db_session_with_containers,
dataset=dataset,
query="test query",
account=account,
@ -452,6 +471,7 @@ class TestHitTestingService:
mock_retrieve.return_value = retrieved_documents
HitTestingService.retrieve(
db_session_with_containers,
dataset=dataset,
query="test query",
account=account,
@ -477,11 +497,15 @@ class TestHitTestingService:
"doc_metadata": {"source": "manual"},
}
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self) -> None:
def test_dump_retrieval_records_returns_dumped_records_without_document_ids(
self, db_session_with_containers: Session
) -> None:
segment = _build_segment(document_id="")
record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95})
records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record]))
records = _DUMPED_RETRIEVAL_RECORDS.validate_python(
HitTestingService._dump_retrieval_records(db_session_with_containers, [record])
)
assert len(records) == 1
assert records[0].segment.id == segment.id
@ -493,7 +517,9 @@ class TestHitTestingService:
segment = _create_segment(db_session_with_containers, document=document)
record = RetrievalSegments.model_validate({"segment": segment, "score": 0.9})
records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record]))
records = _DUMPED_RETRIEVAL_RECORDS.validate_python(
HitTestingService._dump_retrieval_records(db_session_with_containers, [record])
)
assert len(records) == 1
dumped_segment = records[0].segment
@ -515,7 +541,7 @@ class TestHitTestingService:
segment = _create_segment(db_session_with_containers)
record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95})
result = HitTestingService._dump_retrieval_records([record])
result = HitTestingService._dump_retrieval_records(db_session_with_containers, [record])
assert result == []
assert "Skipping hit-testing records with missing documents" in caplog.text

View File

@ -147,8 +147,7 @@ class TestHitTestingServiceRetrieve:
Provides a mocked database session for testing database operations
like adding and committing DatasetQuery records.
"""
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
yield mock_db
return MagicMock()
def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
"""
@ -186,7 +185,9 @@ class TestHitTestingServiceRetrieve:
mock_format.return_value = mock_records
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
result = HitTestingService.retrieve(
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
)
# Assert
assert result["query"]["content"] == query
@ -232,7 +233,9 @@ class TestHitTestingServiceRetrieve:
mock_format.return_value = mock_records
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
result = HitTestingService.retrieve(
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
)
# Assert
assert result["query"]["content"] == query
@ -257,9 +260,11 @@ class TestHitTestingServiceRetrieve:
retrieval_model = {
"metadata_filtering_conditions": {
"conditions": [
{"field": "category", "operator": "is", "value": "test"},
{"name": "category", "comparison_operator": "is", "value": "test"},
],
},
"reranking_enable": False,
"score_threshold_enabled": False,
}
external_retrieval_model = {}
@ -286,7 +291,9 @@ class TestHitTestingServiceRetrieve:
mock_format.return_value = mock_records
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
result = HitTestingService.retrieve(
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
)
# Assert
assert result["query"]["content"] == query
@ -308,9 +315,11 @@ class TestHitTestingServiceRetrieve:
retrieval_model = {
"metadata_filtering_conditions": {
"conditions": [
{"field": "category", "operator": "is", "value": "test"},
{"name": "category", "comparison_operator": "is", "value": "test"},
],
},
"reranking_enable": False,
"score_threshold_enabled": False,
}
external_retrieval_model = {}
@ -327,7 +336,9 @@ class TestHitTestingServiceRetrieve:
mock_format.return_value = []
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
result = HitTestingService.retrieve(
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
)
# Assert
assert result["query"]["content"] == query
@ -344,6 +355,8 @@ class TestHitTestingServiceRetrieve:
dataset_retrieval_model = {
"search_method": RetrievalMethod.HYBRID_SEARCH,
"top_k": 3,
"reranking_enable": False,
"score_threshold_enabled": False,
}
dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=dataset_retrieval_model)
account = HitTestingTestDataFactory.create_user_mock()
@ -366,7 +379,9 @@ class TestHitTestingServiceRetrieve:
mock_format.return_value = mock_records
# Act
result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model)
result = HitTestingService.retrieve(
mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model
)
# Assert
assert result["query"]["content"] == query
@ -391,8 +406,7 @@ class TestHitTestingServiceExternalRetrieve:
Provides a mocked database session for testing database operations
like adding and committing DatasetQuery records.
"""
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
yield mock_db
return MagicMock()
def test_external_retrieve_success(self, mock_db_session):
"""
@ -424,7 +438,7 @@ class TestHitTestingServiceExternalRetrieve:
# Act
result = HitTestingService.external_retrieve(
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions
)
# Assert
@ -455,7 +469,7 @@ class TestHitTestingServiceExternalRetrieve:
# Act
result = HitTestingService.external_retrieve(
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions
)
# Assert
@ -490,7 +504,7 @@ class TestHitTestingServiceExternalRetrieve:
# Act
result = HitTestingService.external_retrieve(
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions
)
# Assert
@ -524,7 +538,7 @@ class TestHitTestingServiceExternalRetrieve:
# Act
result = HitTestingService.external_retrieve(
dataset, query, account, external_retrieval_model, metadata_filtering_conditions
mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions
)
# Assert
@ -565,7 +579,7 @@ class TestHitTestingServiceCompactRetrieveResponse:
mock_format.return_value = mock_records
# Act
result = HitTestingService.compact_retrieve_response(query, documents)
result = HitTestingService.compact_retrieve_response(MagicMock(), query, documents)
# Assert
assert result["query"]["content"] == query
@ -591,7 +605,7 @@ class TestHitTestingServiceCompactRetrieveResponse:
mock_format.return_value = []
# Act
result = HitTestingService.compact_retrieve_response(query, documents)
result = HitTestingService.compact_retrieve_response(MagicMock(), query, documents)
# Assert
assert result["query"]["content"] == query
@ -708,7 +722,7 @@ class TestHitTestingServiceHitTestingArgsCheck:
args = {"query": ""}
# Act & Assert
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_none_query(self):
@ -721,7 +735,7 @@ class TestHitTestingServiceHitTestingArgsCheck:
args = {"query": None}
# Act & Assert
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
with pytest.raises(ValueError, match="Query or attachment_ids is required"):
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_too_long_query(self):
@ -734,7 +748,7 @@ class TestHitTestingServiceHitTestingArgsCheck:
args = {"query": "a" * 251}
# Act & Assert
with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"):
with pytest.raises(ValueError, match="Query cannot exceed 250 characters"):
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_exactly_250_characters(self):