From 7fc8eed7164fdbefc697e2bac61a8c2e5098b5dc Mon Sep 17 00:00:00 2001 From: Myshkin451 <79880574+myshkin451@users.noreply.github.com> Date: Tue, 23 Jun 2026 14:21:38 +0800 Subject: [PATCH] refactor: pass session into hit testing service (#37785) --- api/controllers/console/datasets/external.py | 2 + .../console/datasets/hit_testing_base.py | 2 + api/services/hit_testing_service.py | 30 ++++++----- .../services/test_hit_testing_service.py | 44 +++++++++++---- api/tests/unit_tests/services/hit_service.py | 54 ++++++++++++------- 5 files changed, 90 insertions(+), 42 deletions(-) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 033c9a69af6..eb7b9aa84f8 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -26,6 +26,7 @@ from controllers.console.wraps import ( with_current_tenant_id, with_current_user, ) +from extensions.ext_database import db from fields.base import ResponseModel from fields.dataset_fields import ( dataset_detail_fields, @@ -390,6 +391,7 @@ class ExternalKnowledgeHitTestingApi(Resource): try: response = HitTestingService.external_retrieve( + session=db.session, dataset=dataset, query=payload.query, account=current_user, diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 4e90e66eb25..c343effa9a1 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -18,6 +18,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs.login import resolve_account_fallback from models.account import Account @@ -115,6 +116,7 @@ class DatasetsHitTestingBase: try: current_user, _ = resolve_account_fallback(current_user, current_tenant_id) response = HitTestingService.retrieve( + session=db.session, dataset=dataset, query=cast(str, args.get("query")), account=current_user, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 6c70e324b17..9a2843864d9 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -4,6 +4,7 @@ import time from typing import Any, TypedDict, cast from sqlalchemy import select +from sqlalchemy.orm import Session, scoped_session from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService @@ -12,7 +13,6 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from extensions.ext_database import db from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery @@ -56,7 +56,9 @@ class HitTestingService: } @classmethod - def _dump_retrieval_records(cls, records: list[RetrievalSegments]) -> list[dict[str, Any]]: + def _dump_retrieval_records( + cls, session: Session | scoped_session, records: list[RetrievalSegments] + ) -> list[dict[str, Any]]: document_ids = { document_id for record in records @@ -69,9 +71,7 @@ class HitTestingService: documents = { document.id: cls._dump_dataset_document(document) - for document in db.session.scalars( - select(DatasetDocument).where(DatasetDocument.id.in_(document_ids)) - ).all() + for document in session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all() } records_with_documents: list[dict[str, Any]] = [] @@ -105,6 +105,7 @@ class HitTestingService: @classmethod def retrieve( cls, + session: Session | scoped_session, dataset: Dataset, query: str, account: Account, @@ -142,7 +143,7 @@ class HitTestingService: if metadata_filter_document_ids: document_ids_filter = metadata_filter_document_ids.get(dataset.id, []) if metadata_condition and not document_ids_filter: - return cls.compact_retrieve_response(query, []) + return cls.compact_retrieve_response(session, query, []) all_documents = RetrievalService.retrieve( retrieval_method=RetrievalMethod( resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH) @@ -181,14 +182,15 @@ class HitTestingService: created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) - db.session.add(dataset_query) - db.session.commit() + session.add(dataset_query) + session.commit() - return cls.compact_retrieve_response(query, all_documents) + return cls.compact_retrieve_response(session, query, all_documents) @classmethod def external_retrieve( cls, + session: Session | scoped_session, dataset: Dataset, query: str, account: Account, @@ -222,20 +224,22 @@ class HitTestingService: created_by=account.id, ) - db.session.add(dataset_query) - db.session.commit() + session.add(dataset_query) + session.commit() return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod - def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict: + def compact_retrieve_response( + cls, session: Session | scoped_session, query: str, documents: list[Document] + ) -> RetrieveResponseDict: records = RetrievalService.format_retrieval_documents(documents) return { "query": { "content": query, }, - "records": cls._dump_retrieval_records(records), + "records": cls._dump_retrieval_records(session, records), } @classmethod diff --git a/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index 2d23ae8f68f..fbf993f7d69 100644 --- a/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -181,7 +181,9 @@ class TestHitTestingService: # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") - def test_compact_retrieve_response_should_format_correctly(self, mock_format: MagicMock) -> None: + def test_compact_retrieve_response_should_format_correctly( + self, mock_format: MagicMock, db_session_with_containers: Session + ) -> None: query = "test query" mock_doc = MagicMock(spec=Document) @@ -189,7 +191,9 @@ class TestHitTestingService: mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - response = _RetrieveResponse.model_validate(HitTestingService.compact_retrieve_response(query, [mock_doc])) + response = _RetrieveResponse.model_validate( + HitTestingService.compact_retrieve_response(db_session_with_containers, query, [mock_doc]) + ) assert response.query.content == query assert len(response.records) == 1 @@ -242,6 +246,7 @@ class TestHitTestingService: response = _RetrieveResponse.model_validate( HitTestingService.external_retrieve( + db_session_with_containers, dataset=dataset, query='test "query"', account=account, @@ -269,7 +274,9 @@ class TestHitTestingService: dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - response = _RetrieveResponse.model_validate(HitTestingService.external_retrieve(dataset, "test query", account)) + response = _RetrieveResponse.model_validate( + HitTestingService.external_retrieve(db_session_with_containers, dataset, "test query", account) + ) assert response.query.content == "test query" assert response.records == [] @@ -292,6 +299,7 @@ class TestHitTestingService: response = _RetrieveResponse.model_validate( HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -320,7 +328,11 @@ class TestHitTestingService: retrieval_model = { "search_method": "semantic_search", - "metadata_filtering_conditions": {"some": "condition"}, + "metadata_filtering_conditions": { + "conditions": [ + {"name": "category", "comparison_operator": "is", "value": "test"}, + ], + }, "top_k": 5, "reranking_enable": False, "score_threshold_enabled": False, @@ -330,6 +342,7 @@ class TestHitTestingService: mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -352,7 +365,11 @@ class TestHitTestingService: retrieval_model = { "search_method": "semantic_search", - "metadata_filtering_conditions": {"some": "condition"}, + "metadata_filtering_conditions": { + "conditions": [ + {"name": "category", "comparison_operator": "is", "value": "test"}, + ], + }, "top_k": 5, "reranking_enable": False, "score_threshold_enabled": False, @@ -362,6 +379,7 @@ class TestHitTestingService: response = _RetrieveResponse.model_validate( HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -393,6 +411,7 @@ class TestHitTestingService: mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -452,6 +471,7 @@ class TestHitTestingService: mock_retrieve.return_value = retrieved_documents HitTestingService.retrieve( + db_session_with_containers, dataset=dataset, query="test query", account=account, @@ -477,11 +497,15 @@ class TestHitTestingService: "doc_metadata": {"source": "manual"}, } - def test_dump_retrieval_records_returns_dumped_records_without_document_ids(self) -> None: + def test_dump_retrieval_records_returns_dumped_records_without_document_ids( + self, db_session_with_containers: Session + ) -> None: segment = _build_segment(document_id="") record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95}) - records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record])) + records = _DUMPED_RETRIEVAL_RECORDS.validate_python( + HitTestingService._dump_retrieval_records(db_session_with_containers, [record]) + ) assert len(records) == 1 assert records[0].segment.id == segment.id @@ -493,7 +517,9 @@ class TestHitTestingService: segment = _create_segment(db_session_with_containers, document=document) record = RetrievalSegments.model_validate({"segment": segment, "score": 0.9}) - records = _DUMPED_RETRIEVAL_RECORDS.validate_python(HitTestingService._dump_retrieval_records([record])) + records = _DUMPED_RETRIEVAL_RECORDS.validate_python( + HitTestingService._dump_retrieval_records(db_session_with_containers, [record]) + ) assert len(records) == 1 dumped_segment = records[0].segment @@ -515,7 +541,7 @@ class TestHitTestingService: segment = _create_segment(db_session_with_containers) record = RetrievalSegments.model_validate({"segment": segment, "score": 0.95}) - result = HitTestingService._dump_retrieval_records([record]) + result = HitTestingService._dump_retrieval_records(db_session_with_containers, [record]) assert result == [] assert "Skipping hit-testing records with missing documents" in caplog.text diff --git a/api/tests/unit_tests/services/hit_service.py b/api/tests/unit_tests/services/hit_service.py index ddbc7dc0413..ae19daba898 100644 --- a/api/tests/unit_tests/services/hit_service.py +++ b/api/tests/unit_tests/services/hit_service.py @@ -147,8 +147,7 @@ class TestHitTestingServiceRetrieve: Provides a mocked database session for testing database operations like adding and committing DatasetQuery records. """ - with patch("services.hit_testing_service.db.session", autospec=True) as mock_db: - yield mock_db + return MagicMock() def test_retrieve_success_with_default_retrieval_model(self, mock_db_session): """ @@ -186,7 +185,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = mock_records # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -232,7 +233,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = mock_records # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -257,9 +260,11 @@ class TestHitTestingServiceRetrieve: retrieval_model = { "metadata_filtering_conditions": { "conditions": [ - {"field": "category", "operator": "is", "value": "test"}, + {"name": "category", "comparison_operator": "is", "value": "test"}, ], }, + "reranking_enable": False, + "score_threshold_enabled": False, } external_retrieval_model = {} @@ -286,7 +291,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = mock_records # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -308,9 +315,11 @@ class TestHitTestingServiceRetrieve: retrieval_model = { "metadata_filtering_conditions": { "conditions": [ - {"field": "category", "operator": "is", "value": "test"}, + {"name": "category", "comparison_operator": "is", "value": "test"}, ], }, + "reranking_enable": False, + "score_threshold_enabled": False, } external_retrieval_model = {} @@ -327,7 +336,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = [] # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -344,6 +355,8 @@ class TestHitTestingServiceRetrieve: dataset_retrieval_model = { "search_method": RetrievalMethod.HYBRID_SEARCH, "top_k": 3, + "reranking_enable": False, + "score_threshold_enabled": False, } dataset = HitTestingTestDataFactory.create_dataset_mock(retrieval_model=dataset_retrieval_model) account = HitTestingTestDataFactory.create_user_mock() @@ -366,7 +379,9 @@ class TestHitTestingServiceRetrieve: mock_format.return_value = mock_records # Act - result = HitTestingService.retrieve(dataset, query, account, retrieval_model, external_retrieval_model) + result = HitTestingService.retrieve( + mock_db_session, dataset, query, account, retrieval_model, external_retrieval_model + ) # Assert assert result["query"]["content"] == query @@ -391,8 +406,7 @@ class TestHitTestingServiceExternalRetrieve: Provides a mocked database session for testing database operations like adding and committing DatasetQuery records. """ - with patch("services.hit_testing_service.db.session", autospec=True) as mock_db: - yield mock_db + return MagicMock() def test_external_retrieve_success(self, mock_db_session): """ @@ -424,7 +438,7 @@ class TestHitTestingServiceExternalRetrieve: # Act result = HitTestingService.external_retrieve( - dataset, query, account, external_retrieval_model, metadata_filtering_conditions + mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions ) # Assert @@ -455,7 +469,7 @@ class TestHitTestingServiceExternalRetrieve: # Act result = HitTestingService.external_retrieve( - dataset, query, account, external_retrieval_model, metadata_filtering_conditions + mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions ) # Assert @@ -490,7 +504,7 @@ class TestHitTestingServiceExternalRetrieve: # Act result = HitTestingService.external_retrieve( - dataset, query, account, external_retrieval_model, metadata_filtering_conditions + mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions ) # Assert @@ -524,7 +538,7 @@ class TestHitTestingServiceExternalRetrieve: # Act result = HitTestingService.external_retrieve( - dataset, query, account, external_retrieval_model, metadata_filtering_conditions + mock_db_session, dataset, query, account, external_retrieval_model, metadata_filtering_conditions ) # Assert @@ -565,7 +579,7 @@ class TestHitTestingServiceCompactRetrieveResponse: mock_format.return_value = mock_records # Act - result = HitTestingService.compact_retrieve_response(query, documents) + result = HitTestingService.compact_retrieve_response(MagicMock(), query, documents) # Assert assert result["query"]["content"] == query @@ -591,7 +605,7 @@ class TestHitTestingServiceCompactRetrieveResponse: mock_format.return_value = [] # Act - result = HitTestingService.compact_retrieve_response(query, documents) + result = HitTestingService.compact_retrieve_response(MagicMock(), query, documents) # Assert assert result["query"]["content"] == query @@ -708,7 +722,7 @@ class TestHitTestingServiceHitTestingArgsCheck: args = {"query": ""} # Act & Assert - with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"): + with pytest.raises(ValueError, match="Query or attachment_ids is required"): HitTestingService.hit_testing_args_check(args) def test_hit_testing_args_check_none_query(self): @@ -721,7 +735,7 @@ class TestHitTestingServiceHitTestingArgsCheck: args = {"query": None} # Act & Assert - with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"): + with pytest.raises(ValueError, match="Query or attachment_ids is required"): HitTestingService.hit_testing_args_check(args) def test_hit_testing_args_check_too_long_query(self): @@ -734,7 +748,7 @@ class TestHitTestingServiceHitTestingArgsCheck: args = {"query": "a" * 251} # Act & Assert - with pytest.raises(ValueError, match="Query is required and cannot exceed 250 characters"): + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): HitTestingService.hit_testing_args_check(args) def test_hit_testing_args_check_exactly_250_characters(self):