mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: select in dataset_service (SegmentService and remaining cla… (#34547)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
779e6b8e0b
commit
eca0cdc7a9
@ -14,7 +14,7 @@ from graphon.file import helpers as file_helpers
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
from sqlalchemy import exists, func, select, update
|
||||
from sqlalchemy import delete, exists, func, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -3152,10 +3152,8 @@ class SegmentService:
|
||||
lock_name = f"add_segment_lock_document_id_{document.id}"
|
||||
try:
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
max_position = (
|
||||
db.session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == document.id)
|
||||
.scalar()
|
||||
max_position = db.session.scalar(
|
||||
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@ -3207,7 +3205,7 @@ class SegmentService:
|
||||
segment_document.status = SegmentStatus.ERROR
|
||||
segment_document.error = str(e)
|
||||
db.session.commit()
|
||||
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
|
||||
segment = db.session.get(DocumentSegment, segment_document.id)
|
||||
return segment
|
||||
except LockNotOwnedError:
|
||||
pass
|
||||
@ -3230,10 +3228,8 @@ class SegmentService:
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
max_position = (
|
||||
db.session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == document.id)
|
||||
.scalar()
|
||||
max_position = db.session.scalar(
|
||||
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id)
|
||||
)
|
||||
pre_segment_data_list = []
|
||||
segment_data_list = []
|
||||
@ -3378,11 +3374,7 @@ class SegmentService:
|
||||
else:
|
||||
raise ValueError("The knowledge base index technique is not high quality!")
|
||||
# get the process rule
|
||||
processing_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
|
||||
if processing_rule:
|
||||
VectorService.generate_child_chunks(
|
||||
segment, document, dataset, embedding_model_instance, processing_rule, True
|
||||
@ -3400,13 +3392,13 @@ class SegmentService:
|
||||
# Query existing summary from database
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
existing_summary = db.session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# Check if summary has changed
|
||||
@ -3482,11 +3474,7 @@ class SegmentService:
|
||||
else:
|
||||
raise ValueError("The knowledge base index technique is not high quality!")
|
||||
# get the process rule
|
||||
processing_rule = (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
|
||||
.first()
|
||||
)
|
||||
processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id)
|
||||
if processing_rule:
|
||||
VectorService.generate_child_chunks(
|
||||
segment, document, dataset, embedding_model_instance, processing_rule, True
|
||||
@ -3498,13 +3486,13 @@ class SegmentService:
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
existing_summary = db.session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if args.summary is None:
|
||||
@ -3570,7 +3558,7 @@ class SegmentService:
|
||||
segment.status = SegmentStatus.ERROR
|
||||
segment.error = str(e)
|
||||
db.session.commit()
|
||||
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
|
||||
new_segment = db.session.get(DocumentSegment, segment.id)
|
||||
if not new_segment:
|
||||
raise ValueError("new_segment is not found")
|
||||
return new_segment
|
||||
@ -3590,15 +3578,14 @@ class SegmentService:
|
||||
# Get child chunk IDs before parent segment is deleted
|
||||
child_node_ids = []
|
||||
if segment.index_node_id:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.where(
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
child_node_ids = list(
|
||||
db.session.scalars(
|
||||
select(ChildChunk.index_node_id).where(
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
).all()
|
||||
)
|
||||
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
||||
|
||||
delete_segment_from_index_task.delay(
|
||||
[segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids
|
||||
@ -3617,17 +3604,14 @@ class SegmentService:
|
||||
# Check if segment_ids is not empty to avoid WHERE false condition
|
||||
if not segment_ids or len(segment_ids) == 0:
|
||||
return
|
||||
segments_info = (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count)
|
||||
.where(
|
||||
segments_info = db.session.execute(
|
||||
select(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count).where(
|
||||
DocumentSegment.id.in_(segment_ids),
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
if not segments_info:
|
||||
return
|
||||
@ -3639,15 +3623,16 @@ class SegmentService:
|
||||
# Get child chunk IDs before parent segments are deleted
|
||||
child_node_ids = []
|
||||
if index_node_ids:
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk.index_node_id)
|
||||
.where(
|
||||
ChildChunk.segment_id.in_(segment_db_ids),
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]]
|
||||
child_node_ids = [
|
||||
nid
|
||||
for nid in db.session.scalars(
|
||||
select(ChildChunk.index_node_id).where(
|
||||
ChildChunk.segment_id.in_(segment_db_ids),
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
)
|
||||
).all()
|
||||
if nid
|
||||
]
|
||||
|
||||
# Start async cleanup with both parent and child node IDs
|
||||
if index_node_ids or child_node_ids:
|
||||
@ -3663,7 +3648,7 @@ class SegmentService:
|
||||
db.session.add(document)
|
||||
|
||||
# Delete database records
|
||||
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
|
||||
db.session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
@ -3737,15 +3722,13 @@ class SegmentService:
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
index_node_id = str(uuid.uuid4())
|
||||
index_node_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
db.session.query(func.max(ChildChunk.position))
|
||||
.where(
|
||||
max_position = db.session.scalar(
|
||||
select(func.max(ChildChunk.position)).where(
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
ChildChunk.document_id == document.id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
child_chunk = ChildChunk(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@ -3905,10 +3888,8 @@ class SegmentService:
|
||||
@classmethod
|
||||
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
|
||||
"""Get a child chunk by its ID."""
|
||||
result = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
|
||||
.first()
|
||||
result = db.session.scalar(
|
||||
select(ChildChunk).where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
return result if isinstance(result, ChildChunk) else None
|
||||
|
||||
@ -3943,10 +3924,10 @@ class SegmentService:
|
||||
@classmethod
|
||||
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
|
||||
"""Get a segment by its ID."""
|
||||
result = (
|
||||
db.session.query(DocumentSegment)
|
||||
result = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
return result if isinstance(result, DocumentSegment) else None
|
||||
|
||||
@ -3989,15 +3970,15 @@ class DatasetCollectionBindingService:
|
||||
def get_dataset_collection_binding(
|
||||
cls, provider_name: str, model_name: str, collection_type: str = "dataset"
|
||||
) -> DatasetCollectionBinding:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
dataset_collection_binding = db.session.scalar(
|
||||
select(DatasetCollectionBinding)
|
||||
.where(
|
||||
DatasetCollectionBinding.provider_name == provider_name,
|
||||
DatasetCollectionBinding.model_name == model_name,
|
||||
DatasetCollectionBinding.type == collection_type,
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not dataset_collection_binding:
|
||||
@ -4015,13 +3996,13 @@ class DatasetCollectionBindingService:
|
||||
def get_dataset_collection_binding_by_id_and_type(
|
||||
cls, collection_binding_id: str, collection_type: str = "dataset"
|
||||
) -> DatasetCollectionBinding:
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
dataset_collection_binding = db.session.scalar(
|
||||
select(DatasetCollectionBinding)
|
||||
.where(
|
||||
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
|
||||
)
|
||||
.order_by(DatasetCollectionBinding.created_at)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
raise ValueError("Dataset collection binding not found")
|
||||
@ -4043,7 +4024,7 @@ class DatasetPermissionService:
|
||||
@classmethod
|
||||
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
|
||||
try:
|
||||
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
|
||||
db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
|
||||
permissions = []
|
||||
for user in user_list:
|
||||
permission = DatasetPermission(
|
||||
@ -4079,7 +4060,7 @@ class DatasetPermissionService:
|
||||
@classmethod
|
||||
def clear_partial_member_list(cls, dataset_id):
|
||||
try:
|
||||
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
|
||||
db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id))
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
|
||||
@ -1607,7 +1607,7 @@ class TestDatasetCollectionBindingService:
|
||||
binding = SimpleNamespace(id="binding-1")
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding
|
||||
mock_db.session.scalar.return_value = binding
|
||||
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")
|
||||
|
||||
@ -1619,10 +1619,11 @@ class TestDatasetCollectionBindingService:
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.select"),
|
||||
patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls,
|
||||
patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"),
|
||||
):
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset")
|
||||
|
||||
@ -1638,7 +1639,7 @@ class TestDatasetCollectionBindingService:
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self):
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset collection binding not found"):
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")
|
||||
@ -1647,7 +1648,7 @@ class TestDatasetCollectionBindingService:
|
||||
binding = SimpleNamespace(id="binding-1")
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding
|
||||
mock_db.session.scalar.return_value = binding
|
||||
|
||||
result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")
|
||||
|
||||
@ -1673,7 +1674,7 @@ class TestDatasetPermissionService:
|
||||
[{"user_id": "user-1"}, {"user_id": "user-2"}],
|
||||
)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.delete.assert_called_once()
|
||||
mock_db.session.execute.assert_called()
|
||||
mock_db.session.add_all.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@ -1744,12 +1745,12 @@ class TestDatasetPermissionService:
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
DatasetPermissionService.clear_partial_member_list("dataset-1")
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.delete.assert_called_once()
|
||||
mock_db.session.execute.assert_called()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_clear_partial_member_list_rolls_back_on_exception(self):
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.delete.side_effect = RuntimeError("boom")
|
||||
mock_db.session.execute.side_effect = RuntimeError("boom")
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
DatasetPermissionService.clear_partial_member_list("dataset-1")
|
||||
|
||||
@ -49,7 +49,7 @@ class TestSegmentServiceChildChunks:
|
||||
patch("services.dataset_service.VectorService") as vector_service,
|
||||
):
|
||||
mock_redis.lock.return_value = _make_lock_context()
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = 2
|
||||
mock_db.session.scalar.return_value = 2
|
||||
|
||||
child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset)
|
||||
|
||||
@ -75,7 +75,7 @@ class TestSegmentServiceChildChunks:
|
||||
patch("services.dataset_service.VectorService") as vector_service,
|
||||
):
|
||||
mock_redis.lock.return_value = _make_lock_context()
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed")
|
||||
|
||||
with pytest.raises(ChildChunkIndexingError, match="vector failed"):
|
||||
@ -247,13 +247,13 @@ class TestSegmentServiceQueries:
|
||||
child_chunk = _make_child_chunk()
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = child_chunk
|
||||
mock_db.session.scalar.return_value = child_chunk
|
||||
result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")
|
||||
|
||||
assert result is child_chunk
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace()
|
||||
mock_db.session.scalar.return_value = SimpleNamespace()
|
||||
result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1")
|
||||
|
||||
assert result is None
|
||||
@ -295,13 +295,13 @@ class TestSegmentServiceQueries:
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = segment
|
||||
mock_db.session.scalar.return_value = segment
|
||||
result = SegmentService.get_segment_by_id("segment-1", "tenant-1")
|
||||
|
||||
assert result is segment
|
||||
|
||||
with patch("services.dataset_service.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace()
|
||||
mock_db.session.scalar.return_value = SimpleNamespace()
|
||||
result = SegmentService.get_segment_by_id("segment-1", "tenant-1")
|
||||
|
||||
assert result is None
|
||||
@ -401,11 +401,8 @@ class TestSegmentServiceMutations:
|
||||
):
|
||||
mock_redis.lock.return_value = _make_lock_context()
|
||||
|
||||
max_position_query = MagicMock()
|
||||
max_position_query.where.return_value.scalar.return_value = 2
|
||||
refresh_query = MagicMock()
|
||||
refresh_query.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.query.side_effect = [max_position_query, refresh_query]
|
||||
mock_db.session.scalar.return_value = 2
|
||||
mock_db.session.get.return_value = refreshed_segment
|
||||
|
||||
def add_side_effect(obj):
|
||||
if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None:
|
||||
@ -461,7 +458,7 @@ class TestSegmentServiceMutations:
|
||||
):
|
||||
mock_redis.lock.return_value = _make_lock_context()
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
|
||||
mock_db.session.scalar.return_value = 1
|
||||
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
|
||||
|
||||
result = SegmentService.multi_create_segment(segments, document, dataset)
|
||||
@ -538,7 +535,7 @@ class TestSegmentServiceMutations:
|
||||
patch("services.dataset_service.VectorService") as vector_service,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.get.return_value = refreshed_segment
|
||||
|
||||
result = SegmentService.update_segment(args, segment, document, dataset)
|
||||
|
||||
@ -574,13 +571,10 @@ class TestSegmentServiceMutations:
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance
|
||||
|
||||
processing_rule_query = MagicMock()
|
||||
processing_rule_query.where.return_value.first.return_value = processing_rule
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value.first.return_value = existing_summary
|
||||
refreshed_query = MagicMock()
|
||||
refreshed_query.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]
|
||||
# get calls: processing_rule, then refreshed_segment
|
||||
mock_db.session.get.side_effect = [processing_rule, refreshed_segment]
|
||||
# scalar call: existing_summary
|
||||
mock_db.session.scalar.return_value = existing_summary
|
||||
|
||||
result = SegmentService.update_segment(args, segment, document, dataset)
|
||||
|
||||
@ -621,11 +615,8 @@ class TestSegmentServiceMutations:
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value.first.return_value = existing_summary
|
||||
refreshed_query = MagicMock()
|
||||
refreshed_query.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.query.side_effect = [summary_query, refreshed_query]
|
||||
mock_db.session.scalar.return_value = existing_summary
|
||||
mock_db.session.get.return_value = refreshed_segment
|
||||
|
||||
result = SegmentService.update_segment(args, segment, document, dataset)
|
||||
|
||||
@ -664,11 +655,8 @@ class TestSegmentServiceMutations:
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value.first.return_value = existing_summary
|
||||
refreshed_query = MagicMock()
|
||||
refreshed_query.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.query.side_effect = [summary_query, refreshed_query]
|
||||
mock_db.session.scalar.return_value = existing_summary
|
||||
mock_db.session.get.return_value = refreshed_segment
|
||||
|
||||
result = SegmentService.update_segment(args, segment, document, dataset)
|
||||
|
||||
@ -688,7 +676,7 @@ class TestSegmentServiceMutations:
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)]
|
||||
mock_db.session.scalars.return_value.all.return_value = ["child-1", "child-2"]
|
||||
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
|
||||
@ -727,15 +715,15 @@ class TestSegmentServiceMutations:
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as delete_task,
|
||||
):
|
||||
segments_query = MagicMock()
|
||||
segments_query.with_entities.return_value.where.return_value.all.return_value = [
|
||||
# execute().all() for segments_info (multi-column)
|
||||
execute_result = MagicMock()
|
||||
execute_result.all.return_value = [
|
||||
("node-1", "segment-1", 2),
|
||||
("node-2", "segment-2", 5),
|
||||
]
|
||||
child_query = MagicMock()
|
||||
child_query.where.return_value.all.return_value = [("child-1",)]
|
||||
delete_query = MagicMock()
|
||||
delete_query.where.return_value.delete.return_value = 2
|
||||
mock_db.session.query.side_effect = [segments_query, child_query, delete_query]
|
||||
mock_db.session.execute.return_value = execute_result
|
||||
# scalars() for child_node_ids
|
||||
mock_db.session.scalars.return_value.all.return_value = ["child-1"]
|
||||
|
||||
SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset)
|
||||
|
||||
@ -748,7 +736,6 @@ class TestSegmentServiceMutations:
|
||||
["segment-1", "segment-2"],
|
||||
["child-1"],
|
||||
)
|
||||
delete_query.where.return_value.delete.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_update_segments_status_enables_only_segments_without_indexing_cache(self):
|
||||
@ -868,7 +855,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
||||
patch("services.dataset_service.VectorService") as vector_service,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.get.return_value = refreshed_segment
|
||||
|
||||
result = SegmentService.update_segment(
|
||||
SegmentUpdateArgs(content="question", answer="new answer"),
|
||||
@ -902,11 +889,8 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value.first.return_value = None
|
||||
refreshed_query = MagicMock()
|
||||
refreshed_query.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.query.side_effect = [summary_query, refreshed_query]
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.get.return_value = refreshed_segment
|
||||
|
||||
result = SegmentService.update_segment(
|
||||
SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]),
|
||||
@ -951,13 +935,10 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
||||
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance
|
||||
update_summary.side_effect = RuntimeError("summary failed")
|
||||
|
||||
processing_rule_query = MagicMock()
|
||||
processing_rule_query.where.return_value.first.return_value = processing_rule
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value.first.return_value = existing_summary
|
||||
refreshed_query = MagicMock()
|
||||
refreshed_query.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query]
|
||||
# get calls: processing_rule, then refreshed_segment
|
||||
mock_db.session.get.side_effect = [processing_rule, refreshed_segment]
|
||||
# scalar call: existing_summary
|
||||
mock_db.session.scalar.return_value = existing_summary
|
||||
|
||||
result = SegmentService.update_segment(
|
||||
SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"),
|
||||
@ -1000,7 +981,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
||||
patch("services.dataset_service.VectorService") as vector_service,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment
|
||||
mock_db.session.get.return_value = refreshed_segment
|
||||
|
||||
result = SegmentService.update_segment(
|
||||
SegmentUpdateArgs(content="same content", regenerate_child_chunks=True),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user