From eca0cdc7a993b4279c6a9afc79bfec7608ad6831 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Sat, 4 Apr 2026 19:13:06 -0500 Subject: [PATCH] =?UTF-8?q?refactor:=20select=20in=20dataset=5Fservice=20(?= =?UTF-8?q?SegmentService=20and=20remaining=20cla=E2=80=A6=20(#34547)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/services/dataset_service.py | 121 ++++++++---------- .../services/test_dataset_service_dataset.py | 15 ++- .../services/test_dataset_service_segment.py | 85 +++++------- 3 files changed, 92 insertions(+), 129 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index f7e22e0e89..0795fdb221 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError -from sqlalchemy import exists, func, select, update +from sqlalchemy import delete, exists, func, select, update from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound @@ -3152,10 +3152,8 @@ class SegmentService: lock_name = f"add_segment_lock_document_id_{document.id}" try: with redis_client.lock(lock_name, timeout=600): - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document.id) - .scalar() + max_position = db.session.scalar( + select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id) ) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, @@ -3207,7 +3205,7 @@ class SegmentService: segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() + segment = db.session.get(DocumentSegment, segment_document.id) return segment except LockNotOwnedError: pass @@ -3230,10 +3228,8 @@ class SegmentService: model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document.id) - .scalar() + max_position = db.session.scalar( + select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id) ) pre_segment_data_list = [] segment_data_list = [] @@ -3378,11 +3374,7 @@ class SegmentService: else: raise ValueError("The knowledge base index technique is not high quality!") # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == document.dataset_process_rule_id) - .first() - ) + processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id) if processing_rule: VectorService.generate_child_chunks( segment, document, dataset, embedding_model_instance, processing_rule, True @@ -3400,13 +3392,13 @@ class SegmentService: # Query existing summary from database from models.dataset import DocumentSegmentSummary - existing_summary = ( - db.session.query(DocumentSegmentSummary) + existing_summary = db.session.scalar( + select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment.id, DocumentSegmentSummary.dataset_id == dataset.id, ) - .first() + .limit(1) ) # Check if summary has changed @@ -3482,11 +3474,7 @@ class SegmentService: else: raise ValueError("The knowledge base index technique is not high quality!") # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == document.dataset_process_rule_id) - .first() - ) + processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id) if processing_rule: VectorService.generate_child_chunks( segment, document, dataset, embedding_model_instance, processing_rule, True @@ -3498,13 +3486,13 @@ class SegmentService: if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary - existing_summary = ( - db.session.query(DocumentSegmentSummary) + existing_summary = db.session.scalar( + select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment.id, DocumentSegmentSummary.dataset_id == dataset.id, ) - .first() + .limit(1) ) if args.summary is None: @@ -3570,7 +3558,7 @@ class SegmentService: segment.status = SegmentStatus.ERROR segment.error = str(e) db.session.commit() - new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() + new_segment = db.session.get(DocumentSegment, segment.id) if not new_segment: raise ValueError("new_segment is not found") return new_segment @@ -3590,15 +3578,14 @@ class SegmentService: # Get child chunk IDs before parent segment is deleted child_node_ids = [] if segment.index_node_id: - child_chunks = ( - db.session.query(ChildChunk.index_node_id) - .where( - ChildChunk.segment_id == segment.id, - ChildChunk.dataset_id == dataset.id, - ) - .all() + child_node_ids = list( + db.session.scalars( + select(ChildChunk.index_node_id).where( + ChildChunk.segment_id == segment.id, + ChildChunk.dataset_id == dataset.id, + ) + ).all() ) - child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] delete_segment_from_index_task.delay( [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids @@ -3617,17 +3604,14 @@ class SegmentService: # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - segments_info = ( - db.session.query(DocumentSegment) - .with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count) - .where( + segments_info = db.session.execute( + select(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, DocumentSegment.tenant_id == current_user.current_tenant_id, ) - .all() - ) + ).all() if not segments_info: return @@ -3639,15 +3623,16 @@ class SegmentService: # Get child chunk IDs before parent segments are deleted child_node_ids = [] if index_node_ids: - child_chunks = ( - db.session.query(ChildChunk.index_node_id) - .where( - ChildChunk.segment_id.in_(segment_db_ids), - ChildChunk.dataset_id == dataset.id, - ) - .all() - ) - child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] + child_node_ids = [ + nid + for nid in db.session.scalars( + select(ChildChunk.index_node_id).where( + ChildChunk.segment_id.in_(segment_db_ids), + ChildChunk.dataset_id == dataset.id, + ) + ).all() + if nid + ] # Start async cleanup with both parent and child node IDs if index_node_ids or child_node_ids: @@ -3663,7 +3648,7 @@ class SegmentService: db.session.add(document) # Delete database records - db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() + db.session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))) db.session.commit() @classmethod @@ -3737,15 +3722,13 @@ class SegmentService: with redis_client.lock(lock_name, timeout=20): index_node_id = str(uuid.uuid4()) index_node_hash = helper.generate_text_hash(content) - max_position = ( - db.session.query(func.max(ChildChunk.position)) - .where( + max_position = db.session.scalar( + select(func.max(ChildChunk.position)).where( ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, ChildChunk.segment_id == segment.id, ) - .scalar() ) child_chunk = ChildChunk( tenant_id=current_user.current_tenant_id, @@ -3905,10 +3888,8 @@ class SegmentService: @classmethod def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None: """Get a child chunk by its ID.""" - result = ( - db.session.query(ChildChunk) - .where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) - .first() + result = db.session.scalar( + select(ChildChunk).where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).limit(1) ) return result if isinstance(result, ChildChunk) else None @@ -3943,10 +3924,10 @@ class SegmentService: @classmethod def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None: """Get a segment by its ID.""" - result = ( - db.session.query(DocumentSegment) + result = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) - .first() + .limit(1) ) return result if isinstance(result, DocumentSegment) else None @@ -3989,15 +3970,15 @@ class DatasetCollectionBindingService: def get_dataset_collection_binding( cls, provider_name: str, model_name: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) + dataset_collection_binding = db.session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == provider_name, DatasetCollectionBinding.model_name == model_name, DatasetCollectionBinding.type == collection_type, ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: @@ -4015,13 +3996,13 @@ class DatasetCollectionBindingService: def get_dataset_collection_binding_by_id_and_type( cls, collection_binding_id: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) + dataset_collection_binding = db.session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: raise ValueError("Dataset collection binding not found") @@ -4043,7 +4024,7 @@ class DatasetPermissionService: @classmethod def update_partial_member_list(cls, tenant_id, dataset_id, user_list): try: - db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() + db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id)) permissions = [] for user in user_list: permission = DatasetPermission( @@ -4079,7 +4060,7 @@ class DatasetPermissionService: @classmethod def clear_partial_member_list(cls, dataset_id): try: - db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() + db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id)) db.session.commit() except Exception as e: db.session.rollback() diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py index 849229ff43..64741eb5bb 100644 --- a/api/tests/unit_tests/services/test_dataset_service_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -1607,7 +1607,7 @@ class TestDatasetCollectionBindingService: binding = SimpleNamespace(id="binding-1") with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding + mock_db.session.scalar.return_value = binding result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") @@ -1619,10 +1619,11 @@ class TestDatasetCollectionBindingService: with ( patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.select"), patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls, patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"), ): - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset") @@ -1638,7 +1639,7 @@ class TestDatasetCollectionBindingService: def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self): with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(ValueError, match="Dataset collection binding not found"): DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") @@ -1647,7 +1648,7 @@ class TestDatasetCollectionBindingService: binding = SimpleNamespace(id="binding-1") with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding + mock_db.session.scalar.return_value = binding result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") @@ -1673,7 +1674,7 @@ class TestDatasetPermissionService: [{"user_id": "user-1"}, {"user_id": "user-2"}], ) - mock_db.session.query.return_value.where.return_value.delete.assert_called_once() + mock_db.session.execute.assert_called() mock_db.session.add_all.assert_called_once() mock_db.session.commit.assert_called_once() @@ -1744,12 +1745,12 @@ class TestDatasetPermissionService: with patch("services.dataset_service.db") as mock_db: DatasetPermissionService.clear_partial_member_list("dataset-1") - mock_db.session.query.return_value.where.return_value.delete.assert_called_once() + mock_db.session.execute.assert_called() mock_db.session.commit.assert_called_once() def test_clear_partial_member_list_rolls_back_on_exception(self): with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.delete.side_effect = RuntimeError("boom") + mock_db.session.execute.side_effect = RuntimeError("boom") with pytest.raises(RuntimeError, match="boom"): DatasetPermissionService.clear_partial_member_list("dataset-1") diff --git a/api/tests/unit_tests/services/test_dataset_service_segment.py b/api/tests/unit_tests/services/test_dataset_service_segment.py index 2f8ae14a8e..d6c104708c 100644 --- a/api/tests/unit_tests/services/test_dataset_service_segment.py +++ b/api/tests/unit_tests/services/test_dataset_service_segment.py @@ -49,7 +49,7 @@ class TestSegmentServiceChildChunks: patch("services.dataset_service.VectorService") as vector_service, ): mock_redis.lock.return_value = _make_lock_context() - mock_db.session.query.return_value.where.return_value.scalar.return_value = 2 + mock_db.session.scalar.return_value = 2 child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset) @@ -75,7 +75,7 @@ class TestSegmentServiceChildChunks: patch("services.dataset_service.VectorService") as vector_service, ): mock_redis.lock.return_value = _make_lock_context() - mock_db.session.query.return_value.where.return_value.scalar.return_value = None + mock_db.session.scalar.return_value = None vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed") with pytest.raises(ChildChunkIndexingError, match="vector failed"): @@ -247,13 +247,13 @@ class TestSegmentServiceQueries: child_chunk = _make_child_chunk() with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = child_chunk + mock_db.session.scalar.return_value = child_chunk result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") assert result is child_chunk with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() + mock_db.session.scalar.return_value = SimpleNamespace() result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") assert result is None @@ -295,13 +295,13 @@ class TestSegmentServiceQueries: ) with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = segment + mock_db.session.scalar.return_value = segment result = SegmentService.get_segment_by_id("segment-1", "tenant-1") assert result is segment with patch("services.dataset_service.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() + mock_db.session.scalar.return_value = SimpleNamespace() result = SegmentService.get_segment_by_id("segment-1", "tenant-1") assert result is None @@ -401,11 +401,8 @@ class TestSegmentServiceMutations: ): mock_redis.lock.return_value = _make_lock_context() - max_position_query = MagicMock() - max_position_query.where.return_value.scalar.return_value = 2 - refresh_query = MagicMock() - refresh_query.where.return_value.first.return_value = refreshed_segment - mock_db.session.query.side_effect = [max_position_query, refresh_query] + mock_db.session.scalar.return_value = 2 + mock_db.session.get.return_value = refreshed_segment def add_side_effect(obj): if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None: @@ -461,7 +458,7 @@ class TestSegmentServiceMutations: ): mock_redis.lock.return_value = _make_lock_context() model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model - mock_db.session.query.return_value.where.return_value.scalar.return_value = 1 + mock_db.session.scalar.return_value = 1 vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") result = SegmentService.multi_create_segment(segments, document, dataset) @@ -538,7 +535,7 @@ class TestSegmentServiceMutations: patch("services.dataset_service.VectorService") as vector_service, ): mock_redis.get.return_value = None - mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + mock_db.session.get.return_value = refreshed_segment result = SegmentService.update_segment(args, segment, document, dataset) @@ -574,13 +571,10 @@ class TestSegmentServiceMutations: mock_redis.get.return_value = None model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance - processing_rule_query = MagicMock() - processing_rule_query.where.return_value.first.return_value = processing_rule - summary_query = MagicMock() - summary_query.where.return_value.first.return_value = existing_summary - refreshed_query = MagicMock() - refreshed_query.where.return_value.first.return_value = refreshed_segment - mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query] + # get calls: processing_rule, then refreshed_segment + mock_db.session.get.side_effect = [processing_rule, refreshed_segment] + # scalar call: existing_summary + mock_db.session.scalar.return_value = existing_summary result = SegmentService.update_segment(args, segment, document, dataset) @@ -621,11 +615,8 @@ class TestSegmentServiceMutations: mock_redis.get.return_value = None model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model - summary_query = MagicMock() - summary_query.where.return_value.first.return_value = existing_summary - refreshed_query = MagicMock() - refreshed_query.where.return_value.first.return_value = refreshed_segment - mock_db.session.query.side_effect = [summary_query, refreshed_query] + mock_db.session.scalar.return_value = existing_summary + mock_db.session.get.return_value = refreshed_segment result = SegmentService.update_segment(args, segment, document, dataset) @@ -664,11 +655,8 @@ class TestSegmentServiceMutations: mock_redis.get.return_value = None model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model - summary_query = MagicMock() - summary_query.where.return_value.first.return_value = existing_summary - refreshed_query = MagicMock() - refreshed_query.where.return_value.first.return_value = refreshed_segment - mock_db.session.query.side_effect = [summary_query, refreshed_query] + mock_db.session.scalar.return_value = existing_summary + mock_db.session.get.return_value = refreshed_segment result = SegmentService.update_segment(args, segment, document, dataset) @@ -688,7 +676,7 @@ class TestSegmentServiceMutations: patch("services.dataset_service.delete_segment_from_index_task") as delete_task, ): mock_redis.get.return_value = None - mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)] + mock_db.session.scalars.return_value.all.return_value = ["child-1", "child-2"] SegmentService.delete_segment(segment, document, dataset) @@ -727,15 +715,15 @@ class TestSegmentServiceMutations: patch("services.dataset_service.delete_segment_from_index_task") as delete_task, ): segments_query = MagicMock() - segments_query.with_entities.return_value.where.return_value.all.return_value = [ + # execute().all() for segments_info (multi-column) + execute_result = MagicMock() + execute_result.all.return_value = [ ("node-1", "segment-1", 2), ("node-2", "segment-2", 5), ] - child_query = MagicMock() - child_query.where.return_value.all.return_value = [("child-1",)] - delete_query = MagicMock() - delete_query.where.return_value.delete.return_value = 2 - mock_db.session.query.side_effect = [segments_query, child_query, delete_query] + mock_db.session.execute.return_value = execute_result + # scalars() for child_node_ids + mock_db.session.scalars.return_value.all.return_value = ["child-1"] SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset) @@ -748,7 +736,6 @@ class TestSegmentServiceMutations: ["segment-1", "segment-2"], ["child-1"], ) - delete_query.where.return_value.delete.assert_called_once() mock_db.session.commit.assert_called_once() def test_update_segments_status_enables_only_segments_without_indexing_cache(self): @@ -868,7 +855,7 @@ class TestSegmentServiceAdditionalRegenerationBranches: patch("services.dataset_service.VectorService") as vector_service, ): mock_redis.get.return_value = None - mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + mock_db.session.get.return_value = refreshed_segment result = SegmentService.update_segment( SegmentUpdateArgs(content="question", answer="new answer"), @@ -902,11 +889,8 @@ class TestSegmentServiceAdditionalRegenerationBranches: ): mock_redis.get.return_value = None model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model - summary_query = MagicMock() - summary_query.where.return_value.first.return_value = None - refreshed_query = MagicMock() - refreshed_query.where.return_value.first.return_value = refreshed_segment - mock_db.session.query.side_effect = [summary_query, refreshed_query] + mock_db.session.scalar.return_value = None + mock_db.session.get.return_value = refreshed_segment result = SegmentService.update_segment( SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]), @@ -951,13 +935,10 @@ class TestSegmentServiceAdditionalRegenerationBranches: model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance update_summary.side_effect = RuntimeError("summary failed") - processing_rule_query = MagicMock() - processing_rule_query.where.return_value.first.return_value = processing_rule - summary_query = MagicMock() - summary_query.where.return_value.first.return_value = existing_summary - refreshed_query = MagicMock() - refreshed_query.where.return_value.first.return_value = refreshed_segment - mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query] + # get calls: processing_rule, then refreshed_segment + mock_db.session.get.side_effect = [processing_rule, refreshed_segment] + # scalar call: existing_summary + mock_db.session.scalar.return_value = existing_summary result = SegmentService.update_segment( SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"), @@ -1000,7 +981,7 @@ class TestSegmentServiceAdditionalRegenerationBranches: patch("services.dataset_service.VectorService") as vector_service, ): mock_redis.get.return_value = None - mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + mock_db.session.get.return_value = refreshed_segment result = SegmentService.update_segment( SegmentUpdateArgs(content="same content", regenerate_child_chunks=True),