refactor(services): migrate summary_index_service to SQLAlchemy 2.0 select() API (#34971)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wdeveloper16 2026-04-12 03:37:16 +02:00 committed by GitHub
parent 510120410b
commit 440602f52a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 171 additions and 295 deletions

View File

@ -8,6 +8,7 @@ from typing import TypedDict, cast
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
@ -109,8 +110,13 @@ class SummaryIndexService:
"""
with session_factory.create_session() as session:
# Check if summary record already exists
existing_summary = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
existing_summary = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if existing_summary:
@ -309,8 +315,10 @@ class SummaryIndexService:
summary_record_id,
segment.id,
)
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(DocumentSegmentSummary.id == summary_record_id)
.limit(1)
)
if not summary_record_in_session:
@ -323,10 +331,13 @@ class SummaryIndexService:
dataset.id,
segment.id,
)
summary_record_in_session = (
session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if not summary_record_in_session:
@ -487,8 +498,10 @@ class SummaryIndexService:
with session_factory.create_session() as error_session:
# Try to find the record by id first
# Note: Using assignment only (no type annotation) to avoid redeclaration error
summary_record_in_session = (
error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
summary_record_in_session = error_session.scalar(
select(DocumentSegmentSummary)
.where(DocumentSegmentSummary.id == summary_record_id)
.limit(1)
)
if not summary_record_in_session:
# Try to find by chunk_id and dataset_id
@ -500,10 +513,13 @@ class SummaryIndexService:
dataset.id,
segment.id,
)
summary_record_in_session = (
error_session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record_in_session = error_session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record_in_session:
@ -551,14 +567,12 @@ class SummaryIndexService:
with session_factory.create_session() as session:
# Query existing summary records
existing_summaries = (
session.query(DocumentSegmentSummary)
.filter(
existing_summaries = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset.id,
)
.all()
)
).all()
existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}
# Create or update records
@ -603,8 +617,13 @@ class SummaryIndexService:
error: Error message
"""
with session_factory.create_session() as session:
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@ -639,8 +658,13 @@ class SummaryIndexService:
with session_factory.create_session() as session:
try:
# Get or refresh summary record in this session
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if not summary_record_in_session:
@ -710,8 +734,13 @@ class SummaryIndexService:
except Exception as e:
logger.exception("Failed to generate summary for segment %s", segment.id)
# Update summary record with error status
summary_record_in_session = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record_in_session = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record_in_session:
summary_record_in_session.status = SummaryStatus.ERROR
@ -769,17 +798,17 @@ class SummaryIndexService:
with session_factory.create_session() as session:
# Query segments (only enabled segments)
query = session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True, # Only generate summaries for enabled segments
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled.is_(True), # Only generate summaries for enabled segments
)
if segment_ids:
query = query.filter(DocumentSegment.id.in_(segment_ids))
stmt = stmt.where(DocumentSegment.id.in_(segment_ids))
segments = query.all()
segments = list(session.scalars(stmt).all())
if not segments:
logger.info("No segments found for document %s", document.id)
@ -848,15 +877,15 @@ class SummaryIndexService:
from libs.datetime_utils import naive_utc_now
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(
dataset_id=dataset.id,
enabled=True, # Only disable enabled summaries
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.dataset_id == dataset.id,
DocumentSegmentSummary.enabled.is_(True), # Only disable enabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@ -911,15 +940,15 @@ class SummaryIndexService:
return
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(
dataset_id=dataset.id,
enabled=False, # Only enable disabled summaries
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.dataset_id == dataset.id,
DocumentSegmentSummary.enabled.is_(False), # Only enable disabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@ -935,13 +964,13 @@ class SummaryIndexService:
enabled_count = 0
for summary in summaries:
# Get the original segment
segment = (
session.query(DocumentSegment)
.filter_by(
id=summary.chunk_id,
dataset_id=dataset.id,
segment = session.scalar(
select(DocumentSegment)
.where(
DocumentSegment.id == summary.chunk_id,
DocumentSegment.dataset_id == dataset.id,
)
.first()
.limit(1)
)
# Summary.enabled stays in sync with chunk.enabled,
@ -988,12 +1017,12 @@ class SummaryIndexService:
segment_ids: List of segment IDs to delete summaries for. If None, delete all.
"""
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
summaries = session.scalars(stmt).all()
if not summaries:
return
@ -1046,10 +1075,13 @@ class SummaryIndexService:
# Check if summary_content is empty (whitespace-only strings are considered empty)
if not summary_content or not summary_content.strip():
# If summary is empty, only delete existing summary vector and record
summary_record = (
session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@ -1077,8 +1109,13 @@ class SummaryIndexService:
return None
# Find existing summary record
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
@ -1162,8 +1199,13 @@ class SummaryIndexService:
except Exception as e:
logger.exception("Failed to update summary for segment %s", segment.id)
# Update summary record with error status if it exists
summary_record = (
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.limit(1)
)
if summary_record:
summary_record.status = SummaryStatus.ERROR
@ -1185,14 +1227,14 @@ class SummaryIndexService:
DocumentSegmentSummary instance if found, None otherwise
"""
with session_factory.create_session() as session:
return (
session.query(DocumentSegmentSummary)
return session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
.first()
.limit(1)
)
@staticmethod
@ -1211,15 +1253,13 @@ class SummaryIndexService:
return {}
with session_factory.create_session() as session:
summary_records = (
session.query(DocumentSegmentSummary)
.where(
summary_records = session.scalars(
select(DocumentSegmentSummary).where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
.all()
)
).all()
return {summary.chunk_id: summary for summary in summary_records}
@ -1239,16 +1279,16 @@ class SummaryIndexService:
List of DocumentSegmentSummary instances (only enabled summaries)
"""
with session_factory.create_session() as session:
query = session.query(DocumentSegmentSummary).filter(
stmt = select(DocumentSegmentSummary).where(
DocumentSegmentSummary.document_id == document_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
return query.all()
return list(session.scalars(stmt).all())
@staticmethod
def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None:
@ -1265,16 +1305,15 @@ class SummaryIndexService:
"""
# Get all segments for this document (excluding qa_model and re_segment)
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment.id)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
.all()
segment_ids = list(
session.scalars(
select(DocumentSegment.id).where(
DocumentSegment.document_id == document_id,
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
).all()
)
segment_ids = [seg.id for seg in segments]
if not segment_ids:
return None
@ -1312,15 +1351,13 @@ class SummaryIndexService:
# Get all segments for these documents (excluding qa_model and re_segment)
with session_factory.create_session() as session:
segments = (
session.query(DocumentSegment.id, DocumentSegment.document_id)
.where(
segments = session.execute(
select(DocumentSegment.id, DocumentSegment.document_id).where(
DocumentSegment.document_id.in_(document_ids),
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == tenant_id,
)
.all()
)
).all()
# Group segments by document_id
document_segments_map: dict[str, list[str]] = {}

View File

@ -124,10 +124,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes
existing.disabled_by = "u"
session = MagicMock(name="session")
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = existing
session.query.return_value = query
session.scalar.return_value = existing
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -149,10 +146,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes
def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock(name="session")
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -234,10 +228,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat
# New session used after vectorization succeeds (record not found by id nor chunk_id).
session = MagicMock(name="session")
q1 = MagicMock()
q1.filter_by.return_value = q1
q1.first.side_effect = [None, None]
session.query.return_value = q1
session.scalar.side_effect = [None, None]
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -267,10 +258,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes
# error_session should find record and commit status update
error_session = MagicMock(name="error_session")
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = summary
error_session.query.return_value = q
error_session.scalar.return_value = summary
create_session_mock = MagicMock(return_value=_SessionContext(error_session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -302,10 +290,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo
existing.enabled = False
session = MagicMock()
query = MagicMock()
query.filter.return_value = query
query.all.return_value = [existing]
session.query.return_value = query
session.scalars.return_value.all.return_value = [existing]
monkeypatch.setattr(
summary_module,
@ -324,10 +309,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon
record = _summary_record()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -346,10 +328,7 @@ def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch)
record = _summary_record(summary_content="")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
@ -373,10 +352,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch
record = _summary_record(summary_content="")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
@ -415,10 +391,7 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch
existing = _summary_record(summary_content="old", node_id="old-node")
existing.id = "other-id"
session = MagicMock(name="session")
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, existing] # miss by id, hit by chunk_id
session.query.return_value = q
session.scalar.side_effect = [None, existing] # miss by id, hit by chunk_id
monkeypatch.setattr(
summary_module,
"session_factory",
@ -448,10 +421,7 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte
existing = _summary_record(summary_content="old", node_id="old-node")
session = MagicMock(name="session")
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = existing # hit by id
session.query.return_value = q
session.scalar.return_value = existing # hit by id
monkeypatch.setattr(
summary_module,
"session_factory",
@ -487,10 +457,7 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon
return None
error_session = MagicMock()
q = MagicMock()
q.filter_by.return_value = q
q.first.return_value = summary
error_session.query.return_value = q
error_session.scalar.return_value = summary
create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)])
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -516,21 +483,17 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc
)
session = MagicMock()
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, None] # miss by id and chunk_id
session.query.return_value = q
session.scalar.side_effect = [None, None] # miss by id and chunk_id
error_session = MagicMock()
eq = MagicMock()
eq.filter_by.return_value = eq
eq.first.return_value = summary
error_session.query.return_value = eq
error_session.scalar.return_value = summary
create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)])
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
# Force the created record to be None so the "should not be None" guard triggers.
# Also mock select() so SQLAlchemy doesn't validate the mocked DocumentSegmentSummary as a real column clause.
monkeypatch.setattr(summary_module, "select", MagicMock(return_value=MagicMock()))
monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None))
with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"):
@ -554,10 +517,7 @@ def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_
)
error_session = MagicMock(name="error_session")
q = MagicMock()
q.filter_by.return_value = q
q.first.side_effect = [None, None] # not found by id, not found by chunk_id
error_session.query.return_value = q
error_session.scalar.side_effect = [None, None] # not found by id, not found by chunk_id
monkeypatch.setattr(
summary_module,
@ -577,10 +537,7 @@ def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.Monk
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -599,10 +556,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -646,11 +600,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py
seg2.id = "seg-2"
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [seg1, seg2]
session.query.return_value = query
session.scalars.return_value.all.return_value = [seg1, seg2]
monkeypatch.setattr(
summary_module,
@ -678,11 +628,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch:
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -702,11 +648,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
seg = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [seg]
session.query.return_value = query
session.scalars.return_value.all.return_value = [seg]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -723,7 +665,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
segment_ids=[seg.id],
only_parent_chunks=True,
)
query.filter.assert_called()
session.scalars.assert_called()
def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None:
@ -732,11 +674,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch:
summary2 = _summary_record(summary_content="s", node_id=None)
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [summary1, summary2]
session.query.return_value = query
session.scalars.return_value.all.return_value = [summary1, summary2]
monkeypatch.setattr(
summary_module,
@ -761,11 +699,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch:
def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -793,21 +727,8 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
segment.status = SegmentStatus.COMPLETED
session = MagicMock()
summary_query = MagicMock()
summary_query.filter_by.return_value = summary_query
summary_query.filter.return_value = summary_query
summary_query.all.return_value = [summary]
seg_query = MagicMock()
seg_query.filter_by.return_value = seg_query
seg_query.first.return_value = segment
def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
return summary_query
return seg_query
session.query.side_effect = query_side_effect
session.scalars.return_value.all.return_value = [summary]
session.scalar.return_value = segment
monkeypatch.setattr(
summary_module,
@ -826,11 +747,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -860,21 +777,9 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect
good_segment.status = SegmentStatus.COMPLETED
session = MagicMock()
summary_query = MagicMock()
summary_query.filter_by.return_value = summary_query
summary_query.filter.return_value = summary_query
summary_query.all.return_value = [summary1, summary2, summary3]
session.scalars.return_value.all.return_value = [summary1, summary2, summary3]
session.scalar.side_effect = [bad_segment, good_segment, good_segment]
seg_query = MagicMock()
seg_query.filter_by.return_value = seg_query
seg_query.first.side_effect = [bad_segment, good_segment, good_segment]
def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
return summary_query
return seg_query
session.query.side_effect = query_side_effect
monkeypatch.setattr(
summary_module,
"session_factory",
@ -895,11 +800,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch:
summary = _summary_record(summary_content="sum", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = [summary]
session.query.return_value = query
session.scalars.return_value.all.return_value = [summary]
vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -918,11 +819,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch:
def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
dataset = _dataset()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.filter.return_value = query
query.all.return_value = []
session.query.return_value = query
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -946,10 +843,7 @@ def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch:
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -971,10 +865,7 @@ def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatc
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -996,10 +887,7 @@ def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: py
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1015,10 +903,7 @@ def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch:
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
vector_instance = MagicMock()
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
@ -1044,10 +929,7 @@ def test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: py
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1073,10 +955,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1095,10 +974,7 @@ def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.Monke
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1122,10 +998,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
record = _summary_record(summary_content="old", node_id="n1")
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = record
session.query.return_value = query
session.scalar.return_value = record
session.flush.side_effect = RuntimeError("flush boom")
monkeypatch.setattr(
summary_module,
@ -1143,25 +1016,9 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None:
record = _summary_record(summary_content="sum", node_id="n1")
session = MagicMock()
session.scalar.return_value = record
session.scalars.return_value.all.return_value = [record]
q1 = MagicMock()
q1.where.return_value = q1
q1.first.return_value = record
q2 = MagicMock()
q2.filter.return_value = q2
q2.all.return_value = [record]
def query_side_effect(model: object) -> MagicMock:
if model is summary_module.DocumentSegmentSummary:
# first call used by get_segment_summary, second by get_document_summaries
if not hasattr(query_side_effect, "_called"):
query_side_effect._called = True # type: ignore[attr-defined]
return q1
return q2
return MagicMock()
session.query.side_effect = query_side_effect
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1178,10 +1035,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No
record2 = _summary_record()
record2.chunk_id = "seg-2"
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = [record1, record2]
session.query.return_value = q
session.scalars.return_value.all.return_value = [record1, record2]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1194,10 +1048,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No
def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = []
session.query.return_value = q
session.scalars.return_value.all.return_value = []
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1212,10 +1063,7 @@ def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.Monk
def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None:
session = MagicMock()
q = MagicMock()
q.where.return_value = q
q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")]
session.query.return_value = q
session.execute.return_value.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")]
monkeypatch.setattr(
summary_module,
"session_factory",
@ -1237,10 +1085,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro
segment = _segment()
session = MagicMock()
query = MagicMock()
query.filter_by.return_value = query
query.first.return_value = None
session.query.return_value = query
session.scalar.return_value = None
monkeypatch.setattr(
summary_module,
@ -1267,10 +1112,7 @@ def test_get_segments_summaries_empty_list() -> None:
def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None:
seg_row = SimpleNamespace(id="seg-1", document_id="doc-1")
session = MagicMock()
query = MagicMock()
query.where.return_value = query
query.all.return_value = [SimpleNamespace(id="seg-1")]
session.query.return_value = query
session.scalars.return_value.all.return_value = ["seg-1"] # get_document_summary_index_status returns IDs
create_session_mock = MagicMock(return_value=_SessionContext(session))
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
@ -1283,11 +1125,8 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt
assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING"
# Multiple docs
query2 = MagicMock()
query2.where.return_value = query2
query2.all.return_value = [seg_row]
session2 = MagicMock()
session2.query.return_value = query2
session2.execute.return_value.all.return_value = [seg_row] # get_documents_summary_index_status uses execute
monkeypatch.setattr(
summary_module,
"session_factory",