refactor: select in dataset_service (SegmentService and remaining classes) (#34547)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-04-04 19:13:06 -05:00 committed by GitHub
parent 779e6b8e0b
commit eca0cdc7a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 92 additions and 129 deletions

View File

@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from redis.exceptions import LockNotOwnedError from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select, update from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -3152,10 +3152,8 @@ class SegmentService:
lock_name = f"add_segment_lock_document_id_{document.id}" lock_name = f"add_segment_lock_document_id_{document.id}"
try: try:
with redis_client.lock(lock_name, timeout=600): with redis_client.lock(lock_name, timeout=600):
max_position = ( max_position = db.session.scalar(
db.session.query(func.max(DocumentSegment.position)) select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
.where(DocumentSegment.document_id == document.id)
.scalar()
) )
segment_document = DocumentSegment( segment_document = DocumentSegment(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -3207,7 +3205,7 @@ class SegmentService:
segment_document.status = SegmentStatus.ERROR segment_document.status = SegmentStatus.ERROR
segment_document.error = str(e) segment_document.error = str(e)
db.session.commit() db.session.commit()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() segment = db.session.get(DocumentSegment, segment_document.id)
return segment return segment
except LockNotOwnedError: except LockNotOwnedError:
pass pass
@ -3230,10 +3228,8 @@ class SegmentService:
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
) )
max_position = ( max_position = db.session.scalar(
db.session.query(func.max(DocumentSegment.position)) select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
.where(DocumentSegment.document_id == document.id)
.scalar()
) )
pre_segment_data_list = [] pre_segment_data_list = []
segment_data_list = [] segment_data_list = []
@ -3378,11 +3374,7 @@ class SegmentService:
else: else:
raise ValueError("The knowledge base index technique is not high quality!") raise ValueError("The knowledge base index technique is not high quality!")
# get the process rule # get the process rule
processing_rule = ( processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if processing_rule: if processing_rule:
VectorService.generate_child_chunks( VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True segment, document, dataset, embedding_model_instance, processing_rule, True
@ -3400,13 +3392,13 @@ class SegmentService:
# Query existing summary from database # Query existing summary from database
from models.dataset import DocumentSegmentSummary from models.dataset import DocumentSegmentSummary
existing_summary = ( existing_summary = db.session.scalar(
db.session.query(DocumentSegmentSummary) select(DocumentSegmentSummary)
.where( .where(
DocumentSegmentSummary.chunk_id == segment.id, DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id, DocumentSegmentSummary.dataset_id == dataset.id,
) )
.first() .limit(1)
) )
# Check if summary has changed # Check if summary has changed
@ -3482,11 +3474,7 @@ class SegmentService:
else: else:
raise ValueError("The knowledge base index technique is not high quality!") raise ValueError("The knowledge base index technique is not high quality!")
# get the process rule # get the process rule
processing_rule = ( processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if processing_rule: if processing_rule:
VectorService.generate_child_chunks( VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True segment, document, dataset, embedding_model_instance, processing_rule, True
@ -3498,13 +3486,13 @@ class SegmentService:
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
from models.dataset import DocumentSegmentSummary from models.dataset import DocumentSegmentSummary
existing_summary = ( existing_summary = db.session.scalar(
db.session.query(DocumentSegmentSummary) select(DocumentSegmentSummary)
.where( .where(
DocumentSegmentSummary.chunk_id == segment.id, DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id, DocumentSegmentSummary.dataset_id == dataset.id,
) )
.first() .limit(1)
) )
if args.summary is None: if args.summary is None:
@ -3570,7 +3558,7 @@ class SegmentService:
segment.status = SegmentStatus.ERROR segment.status = SegmentStatus.ERROR
segment.error = str(e) segment.error = str(e)
db.session.commit() db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() new_segment = db.session.get(DocumentSegment, segment.id)
if not new_segment: if not new_segment:
raise ValueError("new_segment is not found") raise ValueError("new_segment is not found")
return new_segment return new_segment
@ -3590,15 +3578,14 @@ class SegmentService:
# Get child chunk IDs before parent segment is deleted # Get child chunk IDs before parent segment is deleted
child_node_ids = [] child_node_ids = []
if segment.index_node_id: if segment.index_node_id:
child_chunks = ( child_node_ids = list(
db.session.query(ChildChunk.index_node_id) db.session.scalars(
.where( select(ChildChunk.index_node_id).where(
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
ChildChunk.dataset_id == dataset.id, ChildChunk.dataset_id == dataset.id,
) )
.all() ).all()
) )
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
delete_segment_from_index_task.delay( delete_segment_from_index_task.delay(
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
@ -3617,17 +3604,14 @@ class SegmentService:
# Check if segment_ids is not empty to avoid WHERE false condition # Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0: if not segment_ids or len(segment_ids) == 0:
return return
segments_info = ( segments_info = db.session.execute(
db.session.query(DocumentSegment) select(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count).where(
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
.where(
DocumentSegment.id.in_(segment_ids), DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id, DocumentSegment.document_id == document.id,
DocumentSegment.tenant_id == current_user.current_tenant_id, DocumentSegment.tenant_id == current_user.current_tenant_id,
) )
.all() ).all()
)
if not segments_info: if not segments_info:
return return
@ -3639,15 +3623,16 @@ class SegmentService:
# Get child chunk IDs before parent segments are deleted # Get child chunk IDs before parent segments are deleted
child_node_ids = [] child_node_ids = []
if index_node_ids: if index_node_ids:
child_chunks = ( child_node_ids = [
db.session.query(ChildChunk.index_node_id) nid
.where( for nid in db.session.scalars(
ChildChunk.segment_id.in_(segment_db_ids), select(ChildChunk.index_node_id).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.segment_id.in_(segment_db_ids),
) ChildChunk.dataset_id == dataset.id,
.all() )
) ).all()
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] if nid
]
# Start async cleanup with both parent and child node IDs # Start async cleanup with both parent and child node IDs
if index_node_ids or child_node_ids: if index_node_ids or child_node_ids:
@ -3663,7 +3648,7 @@ class SegmentService:
db.session.add(document) db.session.add(document)
# Delete database records # Delete database records
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() db.session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
db.session.commit() db.session.commit()
@classmethod @classmethod
@ -3737,15 +3722,13 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=20): with redis_client.lock(lock_name, timeout=20):
index_node_id = str(uuid.uuid4()) index_node_id = str(uuid.uuid4())
index_node_hash = helper.generate_text_hash(content) index_node_hash = helper.generate_text_hash(content)
max_position = ( max_position = db.session.scalar(
db.session.query(func.max(ChildChunk.position)) select(func.max(ChildChunk.position)).where(
.where(
ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id, ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id, ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id, ChildChunk.segment_id == segment.id,
) )
.scalar()
) )
child_chunk = ChildChunk( child_chunk = ChildChunk(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -3905,10 +3888,8 @@ class SegmentService:
@classmethod @classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None: def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
"""Get a child chunk by its ID.""" """Get a child chunk by its ID."""
result = ( result = db.session.scalar(
db.session.query(ChildChunk) select(ChildChunk).where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).limit(1)
.where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
) )
return result if isinstance(result, ChildChunk) else None return result if isinstance(result, ChildChunk) else None
@ -3943,10 +3924,10 @@ class SegmentService:
@classmethod @classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None: def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
"""Get a segment by its ID.""" """Get a segment by its ID."""
result = ( result = db.session.scalar(
db.session.query(DocumentSegment) select(DocumentSegment)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first() .limit(1)
) )
return result if isinstance(result, DocumentSegment) else None return result if isinstance(result, DocumentSegment) else None
@ -3989,15 +3970,15 @@ class DatasetCollectionBindingService:
def get_dataset_collection_binding( def get_dataset_collection_binding(
cls, provider_name: str, model_name: str, collection_type: str = "dataset" cls, provider_name: str, model_name: str, collection_type: str = "dataset"
) -> DatasetCollectionBinding: ) -> DatasetCollectionBinding:
dataset_collection_binding = ( dataset_collection_binding = db.session.scalar(
db.session.query(DatasetCollectionBinding) select(DatasetCollectionBinding)
.where( .where(
DatasetCollectionBinding.provider_name == provider_name, DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name, DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type, DatasetCollectionBinding.type == collection_type,
) )
.order_by(DatasetCollectionBinding.created_at) .order_by(DatasetCollectionBinding.created_at)
.first() .limit(1)
) )
if not dataset_collection_binding: if not dataset_collection_binding:
@ -4015,13 +3996,13 @@ class DatasetCollectionBindingService:
def get_dataset_collection_binding_by_id_and_type( def get_dataset_collection_binding_by_id_and_type(
cls, collection_binding_id: str, collection_type: str = "dataset" cls, collection_binding_id: str, collection_type: str = "dataset"
) -> DatasetCollectionBinding: ) -> DatasetCollectionBinding:
dataset_collection_binding = ( dataset_collection_binding = db.session.scalar(
db.session.query(DatasetCollectionBinding) select(DatasetCollectionBinding)
.where( .where(
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
) )
.order_by(DatasetCollectionBinding.created_at) .order_by(DatasetCollectionBinding.created_at)
.first() .limit(1)
) )
if not dataset_collection_binding: if not dataset_collection_binding:
raise ValueError("Dataset collection binding not found") raise ValueError("Dataset collection binding not found")
@ -4043,7 +4024,7 @@ class DatasetPermissionService:
@classmethod @classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list): def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
try: try:
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
permissions = [] permissions = []
for user in user_list: for user in user_list:
permission = DatasetPermission( permission = DatasetPermission(
@ -4079,7 +4060,7 @@ class DatasetPermissionService:
@classmethod @classmethod
def clear_partial_member_list(cls, dataset_id): def clear_partial_member_list(cls, dataset_id):
try: try:
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()

View File

@ -1607,7 +1607,7 @@ class TestDatasetCollectionBindingService:
binding = SimpleNamespace(id="binding-1") binding = SimpleNamespace(id="binding-1")
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding mock_db.session.scalar.return_value = binding
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")
@ -1619,10 +1619,11 @@ class TestDatasetCollectionBindingService:
with ( with (
patch("services.dataset_service.db") as mock_db, patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.select"),
patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls, patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls,
patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"), patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"),
): ):
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None mock_db.session.scalar.return_value = None
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset") result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset")
@ -1638,7 +1639,7 @@ class TestDatasetCollectionBindingService:
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self): def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self):
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None mock_db.session.scalar.return_value = None
with pytest.raises(ValueError, match="Dataset collection binding not found"): with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")
@ -1647,7 +1648,7 @@ class TestDatasetCollectionBindingService:
binding = SimpleNamespace(id="binding-1") binding = SimpleNamespace(id="binding-1")
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding mock_db.session.scalar.return_value = binding
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")
@ -1673,7 +1674,7 @@ class TestDatasetPermissionService:
[{"user_id": "user-1"}, {"user_id": "user-2"}], [{"user_id": "user-1"}, {"user_id": "user-2"}],
) )
mock_db.session.query.return_value.where.return_value.delete.assert_called_once() mock_db.session.execute.assert_called()
mock_db.session.add_all.assert_called_once() mock_db.session.add_all.assert_called_once()
mock_db.session.commit.assert_called_once() mock_db.session.commit.assert_called_once()
@ -1744,12 +1745,12 @@ class TestDatasetPermissionService:
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
DatasetPermissionService.clear_partial_member_list("dataset-1") DatasetPermissionService.clear_partial_member_list("dataset-1")
mock_db.session.query.return_value.where.return_value.delete.assert_called_once() mock_db.session.execute.assert_called()
mock_db.session.commit.assert_called_once() mock_db.session.commit.assert_called_once()
def test_clear_partial_member_list_rolls_back_on_exception(self): def test_clear_partial_member_list_rolls_back_on_exception(self):
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.delete.side_effect = RuntimeError("boom") mock_db.session.execute.side_effect = RuntimeError("boom")
with pytest.raises(RuntimeError, match="boom"): with pytest.raises(RuntimeError, match="boom"):
DatasetPermissionService.clear_partial_member_list("dataset-1") DatasetPermissionService.clear_partial_member_list("dataset-1")

View File

@ -49,7 +49,7 @@ class TestSegmentServiceChildChunks:
patch("services.dataset_service.VectorService") as vector_service, patch("services.dataset_service.VectorService") as vector_service,
): ):
mock_redis.lock.return_value = _make_lock_context() mock_redis.lock.return_value = _make_lock_context()
mock_db.session.query.return_value.where.return_value.scalar.return_value = 2 mock_db.session.scalar.return_value = 2
child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset) child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset)
@ -75,7 +75,7 @@ class TestSegmentServiceChildChunks:
patch("services.dataset_service.VectorService") as vector_service, patch("services.dataset_service.VectorService") as vector_service,
): ):
mock_redis.lock.return_value = _make_lock_context() mock_redis.lock.return_value = _make_lock_context()
mock_db.session.query.return_value.where.return_value.scalar.return_value = None mock_db.session.scalar.return_value = None
vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed") vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed")
with pytest.raises(ChildChunkIndexingError, match="vector failed"): with pytest.raises(ChildChunkIndexingError, match="vector failed"):
@ -247,13 +247,13 @@ class TestSegmentServiceQueries:
child_chunk = _make_child_chunk() child_chunk = _make_child_chunk()
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = child_chunk mock_db.session.scalar.return_value = child_chunk
result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")
assert result is child_chunk assert result is child_chunk
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() mock_db.session.scalar.return_value = SimpleNamespace()
result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")
assert result is None assert result is None
@ -295,13 +295,13 @@ class TestSegmentServiceQueries:
) )
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = segment mock_db.session.scalar.return_value = segment
result = SegmentService.get_segment_by_id("segment-1", "tenant-1") result = SegmentService.get_segment_by_id("segment-1", "tenant-1")
assert result is segment assert result is segment
with patch("services.dataset_service.db") as mock_db: with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() mock_db.session.scalar.return_value = SimpleNamespace()
result = SegmentService.get_segment_by_id("segment-1", "tenant-1") result = SegmentService.get_segment_by_id("segment-1", "tenant-1")
assert result is None assert result is None
@ -401,11 +401,8 @@ class TestSegmentServiceMutations:
): ):
mock_redis.lock.return_value = _make_lock_context() mock_redis.lock.return_value = _make_lock_context()
max_position_query = MagicMock() mock_db.session.scalar.return_value = 2
max_position_query.where.return_value.scalar.return_value = 2 mock_db.session.get.return_value = refreshed_segment
refresh_query = MagicMock()
refresh_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [max_position_query, refresh_query]
def add_side_effect(obj): def add_side_effect(obj):
if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None: if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None:
@ -461,7 +458,7 @@ class TestSegmentServiceMutations:
): ):
mock_redis.lock.return_value = _make_lock_context() mock_redis.lock.return_value = _make_lock_context()
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1 mock_db.session.scalar.return_value = 1
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
result = SegmentService.multi_create_segment(segments, document, dataset) result = SegmentService.multi_create_segment(segments, document, dataset)
@ -538,7 +535,7 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.VectorService") as vector_service, patch("services.dataset_service.VectorService") as vector_service,
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment(args, segment, document, dataset) result = SegmentService.update_segment(args, segment, document, dataset)
@ -574,13 +571,10 @@ class TestSegmentServiceMutations:
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance
processing_rule_query = MagicMock() # get calls: processing_rule, then refreshed_segment
processing_rule_query.where.return_value.first.return_value = processing_rule mock_db.session.get.side_effect = [processing_rule, refreshed_segment]
summary_query = MagicMock() # scalar call: existing_summary
summary_query.where.return_value.first.return_value = existing_summary mock_db.session.scalar.return_value = existing_summary
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]
result = SegmentService.update_segment(args, segment, document, dataset) result = SegmentService.update_segment(args, segment, document, dataset)
@ -621,11 +615,8 @@ class TestSegmentServiceMutations:
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock() mock_db.session.scalar.return_value = existing_summary
summary_query.where.return_value.first.return_value = existing_summary mock_db.session.get.return_value = refreshed_segment
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [summary_query, refreshed_query]
result = SegmentService.update_segment(args, segment, document, dataset) result = SegmentService.update_segment(args, segment, document, dataset)
@ -664,11 +655,8 @@ class TestSegmentServiceMutations:
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock() mock_db.session.scalar.return_value = existing_summary
summary_query.where.return_value.first.return_value = existing_summary mock_db.session.get.return_value = refreshed_segment
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [summary_query, refreshed_query]
result = SegmentService.update_segment(args, segment, document, dataset) result = SegmentService.update_segment(args, segment, document, dataset)
@ -688,7 +676,7 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.delete_segment_from_index_task") as delete_task, patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)] mock_db.session.scalars.return_value.all.return_value = ["child-1", "child-2"]
SegmentService.delete_segment(segment, document, dataset) SegmentService.delete_segment(segment, document, dataset)
@ -727,15 +715,15 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.delete_segment_from_index_task") as delete_task, patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
): ):
segments_query = MagicMock() segments_query = MagicMock()
segments_query.with_entities.return_value.where.return_value.all.return_value = [ # execute().all() for segments_info (multi-column)
execute_result = MagicMock()
execute_result.all.return_value = [
("node-1", "segment-1", 2), ("node-1", "segment-1", 2),
("node-2", "segment-2", 5), ("node-2", "segment-2", 5),
] ]
child_query = MagicMock() mock_db.session.execute.return_value = execute_result
child_query.where.return_value.all.return_value = [("child-1",)] # scalars() for child_node_ids
delete_query = MagicMock() mock_db.session.scalars.return_value.all.return_value = ["child-1"]
delete_query.where.return_value.delete.return_value = 2
mock_db.session.query.side_effect = [segments_query, child_query, delete_query]
SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset) SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset)
@ -748,7 +736,6 @@ class TestSegmentServiceMutations:
["segment-1", "segment-2"], ["segment-1", "segment-2"],
["child-1"], ["child-1"],
) )
delete_query.where.return_value.delete.assert_called_once()
mock_db.session.commit.assert_called_once() mock_db.session.commit.assert_called_once()
def test_update_segments_status_enables_only_segments_without_indexing_cache(self): def test_update_segments_status_enables_only_segments_without_indexing_cache(self):
@ -868,7 +855,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
patch("services.dataset_service.VectorService") as vector_service, patch("services.dataset_service.VectorService") as vector_service,
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment( result = SegmentService.update_segment(
SegmentUpdateArgs(content="question", answer="new answer"), SegmentUpdateArgs(content="question", answer="new answer"),
@ -902,11 +889,8 @@ class TestSegmentServiceAdditionalRegenerationBranches:
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock() mock_db.session.scalar.return_value = None
summary_query.where.return_value.first.return_value = None mock_db.session.get.return_value = refreshed_segment
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [summary_query, refreshed_query]
result = SegmentService.update_segment( result = SegmentService.update_segment(
SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]), SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]),
@ -951,13 +935,10 @@ class TestSegmentServiceAdditionalRegenerationBranches:
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance
update_summary.side_effect = RuntimeError("summary failed") update_summary.side_effect = RuntimeError("summary failed")
processing_rule_query = MagicMock() # get calls: processing_rule, then refreshed_segment
processing_rule_query.where.return_value.first.return_value = processing_rule mock_db.session.get.side_effect = [processing_rule, refreshed_segment]
summary_query = MagicMock() # scalar call: existing_summary
summary_query.where.return_value.first.return_value = existing_summary mock_db.session.scalar.return_value = existing_summary
refreshed_query = MagicMock()
refreshed_query.where.return_value.first.return_value = refreshed_segment
mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]
result = SegmentService.update_segment( result = SegmentService.update_segment(
SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"), SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"),
@ -1000,7 +981,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
patch("services.dataset_service.VectorService") as vector_service, patch("services.dataset_service.VectorService") as vector_service,
): ):
mock_redis.get.return_value = None mock_redis.get.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment mock_db.session.get.return_value = refreshed_segment
result = SegmentService.update_segment( result = SegmentService.update_segment(
SegmentUpdateArgs(content="same content", regenerate_child_chunks=True), SegmentUpdateArgs(content="same content", regenerate_child_chunks=True),