refactor: select in dataset_service (SegmentService and remaining classes) (#34547)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-04-04 19:13:06 -05:00 committed by GitHub
parent 779e6b8e0b
commit eca0cdc7a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 92 additions and 129 deletions

View File

@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select, update
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
@ -3152,10 +3152,8 @@ class SegmentService:
lock_name = f"add_segment_lock_document_id_{document.id}"
try:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
max_position = db.session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
)
segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id,
@ -3207,7 +3205,7 @@ class SegmentService:
segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
segment = db.session.get(DocumentSegment, segment_document.id)
return segment
except LockNotOwnedError:
pass
@ -3230,10 +3228,8 @@ class SegmentService:
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document.id)
.scalar()
max_position = db.session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
)
pre_segment_data_list = []
segment_data_list = []
@ -3378,11 +3374,7 @@ class SegmentService:
else:
raise ValueError("The knowledge base index technique is not high quality!")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
@ -3400,13 +3392,13 @@ class SegmentService:
# Query existing summary from database
from models.dataset import DocumentSegmentSummary
existing_summary = (
db.session.query(DocumentSegmentSummary)
existing_summary = db.session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
.limit(1)
)
# Check if summary has changed
@ -3482,11 +3474,7 @@ class SegmentService:
else:
raise ValueError("The knowledge base index technique is not high quality!")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
if processing_rule:
VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True
@ -3498,13 +3486,13 @@ class SegmentService:
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
from models.dataset import DocumentSegmentSummary
existing_summary = (
db.session.query(DocumentSegmentSummary)
existing_summary = db.session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
.limit(1)
)
if args.summary is None:
@ -3570,7 +3558,7 @@ class SegmentService:
segment.status = SegmentStatus.ERROR
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
new_segment = db.session.get(DocumentSegment, segment.id)
if not new_segment:
raise ValueError("new_segment is not found")
return new_segment
@ -3590,15 +3578,14 @@ class SegmentService:
# Get child chunk IDs before parent segment is deleted
child_node_ids = []
if segment.index_node_id:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
.all()
child_node_ids = list(
db.session.scalars(
select(ChildChunk.index_node_id).where(
ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id,
)
).all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
@ -3617,17 +3604,14 @@ class SegmentService:
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
segments_info = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
.where(
segments_info = db.session.execute(
select(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.all()
)
).all()
if not segments_info:
return
@ -3639,15 +3623,16 @@ class SegmentService:
# Get child chunk IDs before parent segments are deleted
child_node_ids = []
if index_node_ids:
child_chunks = (
db.session.query(ChildChunk.index_node_id)
.where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
child_node_ids = [
nid
for nid in db.session.scalars(
select(ChildChunk.index_node_id).where(
ChildChunk.segment_id.in_(segment_db_ids),
ChildChunk.dataset_id == dataset.id,
)
).all()
if nid
]
# Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids:
@ -3663,7 +3648,7 @@ class SegmentService:
db.session.add(document)
# Delete database records
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
db.session.commit()
@classmethod
@ -3737,15 +3722,13 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(content)
max_position = (
db.session.query(func.max(ChildChunk.position))
.where(
max_position = db.session.scalar(
select(func.max(ChildChunk.position)).where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.scalar()
)
child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id,
@ -3905,10 +3888,8 @@ class SegmentService:
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
.where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
result = db.session.scalar(
select(ChildChunk).where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).limit(1)
)
return result if isinstance(result, ChildChunk) else None
@ -3943,10 +3924,10 @@ class SegmentService:
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
result = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
.limit(1)
)
return result if isinstance(result, DocumentSegment) else None
@ -3989,15 +3970,15 @@ class DatasetCollectionBindingService:
def get_dataset_collection_binding(
cls, provider_name: str, model_name: str, collection_type: str = "dataset"
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
dataset_collection_binding = db.session.scalar(
select(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type,
)
.order_by(DatasetCollectionBinding.created_at)
.first()
.limit(1)
)
if not dataset_collection_binding:
@ -4015,13 +3996,13 @@ class DatasetCollectionBindingService:
def get_dataset_collection_binding_by_id_and_type(
cls, collection_binding_id: str, collection_type: str = "dataset"
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
dataset_collection_binding = db.session.scalar(
select(DatasetCollectionBinding)
.where(
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
)
.order_by(DatasetCollectionBinding.created_at)
.first()
.limit(1)
)
if not dataset_collection_binding:
raise ValueError("Dataset collection binding not found")
@ -4043,7 +4024,7 @@ class DatasetPermissionService:
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
try:
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
permissions = []
for user in user_list:
permission = DatasetPermission(
@ -4079,7 +4060,7 @@ class DatasetPermissionService:
@classmethod
def clear_partial_member_list(cls, dataset_id):
try:
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
db.session.commit()
except Exception as e:
db.session.rollback()

View File

@ -1607,7 +1607,7 @@ class TestDatasetCollectionBindingService:
binding = SimpleNamespace(id="binding-1")
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding
mock_db.session.scalar.return_value = binding
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")
@ -1619,10 +1619,11 @@ class TestDatasetCollectionBindingService:
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.select"),
patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls,
patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"),
):
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset")
@ -1638,7 +1639,7 @@ class TestDatasetCollectionBindingService:
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self):
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")
@ -1647,7 +1648,7 @@ class TestDatasetCollectionBindingService:
binding = SimpleNamespace(id="binding-1")
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding
mock_db.session.scalar.return_value = binding
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")
@ -1673,7 +1674,7 @@ class TestDatasetPermissionService:
[{"user_id": "user-1"}, {"user_id": "user-2"}],
)
mock_db.session.query.return_value.where.return_value.delete.assert_called_once()
mock_db.session.execute.assert_called()
mock_db.session.add_all.assert_called_once()
mock_db.session.commit.assert_called_once()
@ -1744,12 +1745,12 @@ class TestDatasetPermissionService:
with patch("services.dataset_service.db") as mock_db:
DatasetPermissionService.clear_partial_member_list("dataset-1")
mock_db.session.query.return_value.where.return_value.delete.assert_called_once()
mock_db.session.execute.assert_called()
mock_db.session.commit.assert_called_once()
def test_clear_partial_member_list_rolls_back_on_exception(self):
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.delete.side_effect = RuntimeError("boom")
mock_db.session.execute.side_effect = RuntimeError("boom")
with pytest.raises(RuntimeError, match="boom"):
DatasetPermissionService.clear_partial_member_list("dataset-1")

View File

@ -49,7 +49,7 @@ class TestSegmentServiceChildChunks:
patch("services.dataset_service.VectorService") as vector_service,
):
mock_redis.lock.return_value = _make_lock_context()
mock_db.session.query.return_value.where.return_value.scalar.return_value = 2
mock_db.session.scalar.return_value = 2
child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset)
@ -75,7 +75,7 @@ class TestSegmentServiceChildChunks:
patch("services.dataset_service.VectorService") as vector_service,
):
mock_redis.lock.return_value = _make_lock_context()
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
mock_db.session.scalar.return_value = None
vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed")
with pytest.raises(ChildChunkIndexingError, match="vector failed"):
@ -247,13 +247,13 @@ class TestSegmentServiceQueries:
child_chunk = _make_child_chunk()
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = child_chunk
mock_db.session.scalar.return_value = child_chunk
result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")
assert result is child_chunk
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace()
mock_db.session.scalar.return_value = SimpleNamespace()
result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")
assert result is None
@ -295,13 +295,13 @@ class TestSegmentServiceQueries:
)
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = segment
mock_db.session.scalar.return_value = segment
result = SegmentService.get_segment_by_id("segment-1", "tenant-1")
assert result is segment
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace()
mock_db.session.scalar.return_value = SimpleNamespace()
result = SegmentService.get_segment_by_id("segment-1", "tenant-1")
assert result is None
@ -401,11 +401,8 @@ class TestSegmentServiceMutations:
):
mock_redis.lock.return_value = _make_lock_context()
max_position_query = MagicMock()
max_position_query.where.return_value.scalar.return_value = 2
refresh_query = MagicMock()
refresh_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [max_position_query, refresh_query]
mock_db.session.scalar.return_value = 2
mock_db.session.get.return_value = refreshed_segment
def add_side_effect(obj):
if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None:
@ -461,7 +458,7 @@ class TestSegmentServiceMutations:
):
mock_redis.lock.return_value = _make_lock_context()
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
mock_db.session.scalar.return_value = 1
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
result = SegmentService.multi_create_segment(segments, document, dataset)
@ -538,7 +535,7 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.VectorService") as vector_service,
):
mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment(args, segment, document, dataset)
@ -574,13 +571,10 @@ class TestSegmentServiceMutations:
mock_redis.get.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance
processing_rule_query = MagicMock()
processing_rule_query.where.return_value.first.return_value = processing_rule
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]
# get calls: processing_rule, then refreshed_segment
mock_db.session.get.side_effect = [processing_rule, refreshed_segment]
# scalar call: existing_summary
mock_db.session.scalar.return_value = existing_summary
result = SegmentService.update_segment(args, segment, document, dataset)
@ -621,11 +615,8 @@ class TestSegmentServiceMutations:
mock_redis.get.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [summary_query, refreshed_query]
mock_db.session.scalar.return_value = existing_summary
mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment(args, segment, document, dataset)
@ -664,11 +655,8 @@ class TestSegmentServiceMutations:
mock_redis.get.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [summary_query, refreshed_query]
mock_db.session.scalar.return_value = existing_summary
mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment(args, segment, document, dataset)
@ -688,7 +676,7 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
):
mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)]
mock_db.session.scalars.return_value.all.return_value = ["child-1", "child-2"]
SegmentService.delete_segment(segment, document, dataset)
@ -727,15 +715,15 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
):
segments_query = MagicMock()
segments_query.with_entities.return_value.where.return_value.all.return_value = [
# execute().all() for segments_info (multi-column)
execute_result = MagicMock()
execute_result.all.return_value = [
("node-1", "segment-1", 2),
("node-2", "segment-2", 5),
]
child_query = MagicMock()
child_query.where.return_value.all.return_value = [("child-1",)]
delete_query = MagicMock()
delete_query.where.return_value.delete.return_value = 2
mock_db.session.query.side_effect = [segments_query, child_query, delete_query]
mock_db.session.execute.return_value = execute_result
# scalars() for child_node_ids
mock_db.session.scalars.return_value.all.return_value = ["child-1"]
SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset)
@ -748,7 +736,6 @@ class TestSegmentServiceMutations:
["segment-1", "segment-2"],
["child-1"],
)
delete_query.where.return_value.delete.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_update_segments_status_enables_only_segments_without_indexing_cache(self):
@ -868,7 +855,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
patch("services.dataset_service.VectorService") as vector_service,
):
mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment(
SegmentUpdateArgs(content="question", answer="new answer"),
@ -902,11 +889,8 @@ class TestSegmentServiceAdditionalRegenerationBranches:
):
mock_redis.get.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = None
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [summary_query, refreshed_query]
mock_db.session.scalar.return_value = None
mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment(
SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]),
@ -951,13 +935,10 @@ class TestSegmentServiceAdditionalRegenerationBranches:
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance
update_summary.side_effect = RuntimeError("summary failed")
processing_rule_query = MagicMock()
processing_rule_query.where.return_value.first.return_value = processing_rule
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]
# get calls: processing_rule, then refreshed_segment
mock_db.session.get.side_effect = [processing_rule, refreshed_segment]
# scalar call: existing_summary
mock_db.session.scalar.return_value = existing_summary
result = SegmentService.update_segment(
SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"),
@ -1000,7 +981,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
patch("services.dataset_service.VectorService") as vector_service,
):
mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment(
SegmentUpdateArgs(content="same content", regenerate_child_chunks=True),