mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
refactor(services): migrate summary_index_service to SQLAlchemy 2.0 select() API (#34971)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
510120410b
commit
440602f52a
@ -8,6 +8,7 @@ from typing import TypedDict, cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
@ -109,8 +110,13 @@ class SummaryIndexService:
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
# Check if summary record already exists
|
||||
existing_summary = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
existing_summary = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if existing_summary:
|
||||
@ -309,8 +315,10 @@ class SummaryIndexService:
|
||||
summary_record_id,
|
||||
segment.id,
|
||||
)
|
||||
summary_record_in_session = (
|
||||
session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
|
||||
summary_record_in_session = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(DocumentSegmentSummary.id == summary_record_id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not summary_record_in_session:
|
||||
@ -323,10 +331,13 @@ class SummaryIndexService:
|
||||
dataset.id,
|
||||
segment.id,
|
||||
)
|
||||
summary_record_in_session = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
|
||||
.first()
|
||||
summary_record_in_session = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not summary_record_in_session:
|
||||
@ -487,8 +498,10 @@ class SummaryIndexService:
|
||||
with session_factory.create_session() as error_session:
|
||||
# Try to find the record by id first
|
||||
# Note: Using assignment only (no type annotation) to avoid redeclaration error
|
||||
summary_record_in_session = (
|
||||
error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
|
||||
summary_record_in_session = error_session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(DocumentSegmentSummary.id == summary_record_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not summary_record_in_session:
|
||||
# Try to find by chunk_id and dataset_id
|
||||
@ -500,10 +513,13 @@ class SummaryIndexService:
|
||||
dataset.id,
|
||||
segment.id,
|
||||
)
|
||||
summary_record_in_session = (
|
||||
error_session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
|
||||
.first()
|
||||
summary_record_in_session = error_session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if summary_record_in_session:
|
||||
@ -551,14 +567,12 @@ class SummaryIndexService:
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
# Query existing summary records
|
||||
existing_summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
existing_summaries = session.scalars(
|
||||
select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}
|
||||
|
||||
# Create or update records
|
||||
@ -603,8 +617,13 @@ class SummaryIndexService:
|
||||
error: Error message
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if summary_record:
|
||||
@ -639,8 +658,13 @@ class SummaryIndexService:
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
# Get or refresh summary record in this session
|
||||
summary_record_in_session = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record_in_session = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not summary_record_in_session:
|
||||
@ -710,8 +734,13 @@ class SummaryIndexService:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate summary for segment %s", segment.id)
|
||||
# Update summary record with error status
|
||||
summary_record_in_session = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record_in_session = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if summary_record_in_session:
|
||||
summary_record_in_session.status = SummaryStatus.ERROR
|
||||
@ -769,17 +798,17 @@ class SummaryIndexService:
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
# Query segments (only enabled segments)
|
||||
query = session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True, # Only generate summaries for enabled segments
|
||||
stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled.is_(True), # Only generate summaries for enabled segments
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegment.id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegment.id.in_(segment_ids))
|
||||
|
||||
segments = query.all()
|
||||
segments = list(session.scalars(stmt).all())
|
||||
|
||||
if not segments:
|
||||
logger.info("No segments found for document %s", document.id)
|
||||
@ -848,15 +877,15 @@ class SummaryIndexService:
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
query = session.query(DocumentSegmentSummary).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
enabled=True, # Only disable enabled summaries
|
||||
stmt = select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only disable enabled summaries
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
summaries = session.scalars(stmt).all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
@ -911,15 +940,15 @@ class SummaryIndexService:
|
||||
return
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
query = session.query(DocumentSegmentSummary).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
enabled=False, # Only enable disabled summaries
|
||||
stmt = select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
DocumentSegmentSummary.enabled.is_(False), # Only enable disabled summaries
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
summaries = session.scalars(stmt).all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
@ -935,13 +964,13 @@ class SummaryIndexService:
|
||||
enabled_count = 0
|
||||
for summary in summaries:
|
||||
# Get the original segment
|
||||
segment = (
|
||||
session.query(DocumentSegment)
|
||||
.filter_by(
|
||||
id=summary.chunk_id,
|
||||
dataset_id=dataset.id,
|
||||
segment = session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.id == summary.chunk_id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# Summary.enabled stays in sync with chunk.enabled,
|
||||
@ -988,12 +1017,12 @@ class SummaryIndexService:
|
||||
segment_ids: List of segment IDs to delete summaries for. If None, delete all.
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
|
||||
stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
summaries = session.scalars(stmt).all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
@ -1046,10 +1075,13 @@ class SummaryIndexService:
|
||||
# Check if summary_content is empty (whitespace-only strings are considered empty)
|
||||
if not summary_content or not summary_content.strip():
|
||||
# If summary is empty, only delete existing summary vector and record
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
|
||||
.first()
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if summary_record:
|
||||
@ -1077,8 +1109,13 @@ class SummaryIndexService:
|
||||
return None
|
||||
|
||||
# Find existing summary record
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if summary_record:
|
||||
@ -1162,8 +1199,13 @@ class SummaryIndexService:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update summary for segment %s", segment.id)
|
||||
# Update summary record with error status if it exists
|
||||
summary_record = (
|
||||
session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if summary_record:
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
@ -1185,14 +1227,14 @@ class SummaryIndexService:
|
||||
DocumentSegmentSummary instance if found, None otherwise
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
return (
|
||||
session.query(DocumentSegmentSummary)
|
||||
return session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment_id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -1211,15 +1253,13 @@ class SummaryIndexService:
|
||||
return {}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
summary_records = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
summary_records = session.scalars(
|
||||
select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
return {summary.chunk_id: summary for summary in summary_records}
|
||||
|
||||
@ -1239,16 +1279,16 @@ class SummaryIndexService:
|
||||
List of DocumentSegmentSummary instances (only enabled summaries)
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
query = session.query(DocumentSegmentSummary).filter(
|
||||
stmt = select(DocumentSegmentSummary).where(
|
||||
DocumentSegmentSummary.document_id == document_id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
return query.all()
|
||||
return list(session.scalars(stmt).all())
|
||||
|
||||
@staticmethod
|
||||
def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None:
|
||||
@ -1265,16 +1305,15 @@ class SummaryIndexService:
|
||||
"""
|
||||
# Get all segments for this document (excluding qa_model and re_segment)
|
||||
with session_factory.create_session() as session:
|
||||
segments = (
|
||||
session.query(DocumentSegment.id)
|
||||
.where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == tenant_id,
|
||||
)
|
||||
.all()
|
||||
segment_ids = list(
|
||||
session.scalars(
|
||||
select(DocumentSegment.id).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == tenant_id,
|
||||
)
|
||||
).all()
|
||||
)
|
||||
segment_ids = [seg.id for seg in segments]
|
||||
|
||||
if not segment_ids:
|
||||
return None
|
||||
@ -1312,15 +1351,13 @@ class SummaryIndexService:
|
||||
|
||||
# Get all segments for these documents (excluding qa_model and re_segment)
|
||||
with session_factory.create_session() as session:
|
||||
segments = (
|
||||
session.query(DocumentSegment.id, DocumentSegment.document_id)
|
||||
.where(
|
||||
segments = session.execute(
|
||||
select(DocumentSegment.id, DocumentSegment.document_id).where(
|
||||
DocumentSegment.document_id.in_(document_ids),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group segments by document_id
|
||||
document_segments_map: dict[str, list[str]] = {}
|
||||
|
||||
@ -124,10 +124,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes
|
||||
existing.disabled_by = "u"
|
||||
|
||||
session = MagicMock(name="session")
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = existing
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = existing
|
||||
|
||||
create_session_mock = MagicMock(return_value=_SessionContext(session))
|
||||
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
|
||||
@ -149,10 +146,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes
|
||||
|
||||
def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = MagicMock(name="session")
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = None
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
|
||||
create_session_mock = MagicMock(return_value=_SessionContext(session))
|
||||
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
|
||||
@ -234,10 +228,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat
|
||||
|
||||
# New session used after vectorization succeeds (record not found by id nor chunk_id).
|
||||
session = MagicMock(name="session")
|
||||
q1 = MagicMock()
|
||||
q1.filter_by.return_value = q1
|
||||
q1.first.side_effect = [None, None]
|
||||
session.query.return_value = q1
|
||||
session.scalar.side_effect = [None, None]
|
||||
|
||||
create_session_mock = MagicMock(return_value=_SessionContext(session))
|
||||
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
|
||||
@ -267,10 +258,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes
|
||||
|
||||
# error_session should find record and commit status update
|
||||
error_session = MagicMock(name="error_session")
|
||||
q = MagicMock()
|
||||
q.filter_by.return_value = q
|
||||
q.first.return_value = summary
|
||||
error_session.query.return_value = q
|
||||
error_session.scalar.return_value = summary
|
||||
|
||||
create_session_mock = MagicMock(return_value=_SessionContext(error_session))
|
||||
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
|
||||
@ -302,10 +290,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo
|
||||
existing.enabled = False
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = [existing]
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = [existing]
|
||||
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -324,10 +309,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon
|
||||
record = _summary_record()
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -346,10 +328,7 @@ def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch)
|
||||
record = _summary_record(summary_content="")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -373,10 +352,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch
|
||||
record = _summary_record(summary_content="")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -415,10 +391,7 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch
|
||||
existing = _summary_record(summary_content="old", node_id="old-node")
|
||||
existing.id = "other-id"
|
||||
session = MagicMock(name="session")
|
||||
q = MagicMock()
|
||||
q.filter_by.return_value = q
|
||||
q.first.side_effect = [None, existing] # miss by id, hit by chunk_id
|
||||
session.query.return_value = q
|
||||
session.scalar.side_effect = [None, existing] # miss by id, hit by chunk_id
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -448,10 +421,7 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte
|
||||
|
||||
existing = _summary_record(summary_content="old", node_id="old-node")
|
||||
session = MagicMock(name="session")
|
||||
q = MagicMock()
|
||||
q.filter_by.return_value = q
|
||||
q.first.return_value = existing # hit by id
|
||||
session.query.return_value = q
|
||||
session.scalar.return_value = existing # hit by id
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -487,10 +457,7 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon
|
||||
return None
|
||||
|
||||
error_session = MagicMock()
|
||||
q = MagicMock()
|
||||
q.filter_by.return_value = q
|
||||
q.first.return_value = summary
|
||||
error_session.query.return_value = q
|
||||
error_session.scalar.return_value = summary
|
||||
|
||||
create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)])
|
||||
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
|
||||
@ -516,21 +483,17 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
q = MagicMock()
|
||||
q.filter_by.return_value = q
|
||||
q.first.side_effect = [None, None] # miss by id and chunk_id
|
||||
session.query.return_value = q
|
||||
session.scalar.side_effect = [None, None] # miss by id and chunk_id
|
||||
|
||||
error_session = MagicMock()
|
||||
eq = MagicMock()
|
||||
eq.filter_by.return_value = eq
|
||||
eq.first.return_value = summary
|
||||
error_session.query.return_value = eq
|
||||
error_session.scalar.return_value = summary
|
||||
|
||||
create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)])
|
||||
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
|
||||
|
||||
# Force the created record to be None so the "should not be None" guard triggers.
|
||||
# Also mock select() so SQLAlchemy doesn't validate the mocked DocumentSegmentSummary as a real column clause.
|
||||
monkeypatch.setattr(summary_module, "select", MagicMock(return_value=MagicMock()))
|
||||
monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None))
|
||||
|
||||
with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"):
|
||||
@ -554,10 +517,7 @@ def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_
|
||||
)
|
||||
|
||||
error_session = MagicMock(name="error_session")
|
||||
q = MagicMock()
|
||||
q.filter_by.return_value = q
|
||||
q.first.side_effect = [None, None] # not found by id, not found by chunk_id
|
||||
error_session.query.return_value = q
|
||||
error_session.scalar.side_effect = [None, None] # not found by id, not found by chunk_id
|
||||
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -577,10 +537,7 @@ def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.Monk
|
||||
segment = _segment()
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = None
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -599,10 +556,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo
|
||||
segment = _segment()
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = None
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -646,11 +600,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py
|
||||
seg2.id = "seg-2"
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = [seg1, seg2]
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = [seg1, seg2]
|
||||
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -678,11 +628,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch:
|
||||
document.doc_form = IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = []
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = []
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -702,11 +648,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
|
||||
seg = _segment()
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = [seg]
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = [seg]
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -723,7 +665,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu
|
||||
segment_ids=[seg.id],
|
||||
only_parent_chunks=True,
|
||||
)
|
||||
query.filter.assert_called()
|
||||
session.scalars.assert_called()
|
||||
|
||||
|
||||
def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -732,11 +674,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch:
|
||||
summary2 = _summary_record(summary_content="s", node_id=None)
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = [summary1, summary2]
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = [summary1, summary2]
|
||||
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -761,11 +699,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch:
|
||||
def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _dataset()
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = []
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = []
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -793,21 +727,8 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
|
||||
segment.status = SegmentStatus.COMPLETED
|
||||
|
||||
session = MagicMock()
|
||||
summary_query = MagicMock()
|
||||
summary_query.filter_by.return_value = summary_query
|
||||
summary_query.filter.return_value = summary_query
|
||||
summary_query.all.return_value = [summary]
|
||||
|
||||
seg_query = MagicMock()
|
||||
seg_query.filter_by.return_value = seg_query
|
||||
seg_query.first.return_value = segment
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is summary_module.DocumentSegmentSummary:
|
||||
return summary_query
|
||||
return seg_query
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
session.scalars.return_value.all.return_value = [summary]
|
||||
session.scalar.return_value = segment
|
||||
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -826,11 +747,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt
|
||||
def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _dataset()
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = []
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = []
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -860,21 +777,9 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect
|
||||
good_segment.status = SegmentStatus.COMPLETED
|
||||
|
||||
session = MagicMock()
|
||||
summary_query = MagicMock()
|
||||
summary_query.filter_by.return_value = summary_query
|
||||
summary_query.filter.return_value = summary_query
|
||||
summary_query.all.return_value = [summary1, summary2, summary3]
|
||||
session.scalars.return_value.all.return_value = [summary1, summary2, summary3]
|
||||
session.scalar.side_effect = [bad_segment, good_segment, good_segment]
|
||||
|
||||
seg_query = MagicMock()
|
||||
seg_query.filter_by.return_value = seg_query
|
||||
seg_query.first.side_effect = [bad_segment, good_segment, good_segment]
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is summary_module.DocumentSegmentSummary:
|
||||
return summary_query
|
||||
return seg_query
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -895,11 +800,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch:
|
||||
summary = _summary_record(summary_content="sum", node_id="n1")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = [summary]
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = [summary]
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
@ -918,11 +819,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch:
|
||||
def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
dataset = _dataset()
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.filter.return_value = query
|
||||
query.all.return_value = []
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = []
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -946,10 +843,7 @@ def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch:
|
||||
record = _summary_record(summary_content="old", node_id="n1")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
@ -971,10 +865,7 @@ def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatc
|
||||
record = _summary_record(summary_content="old", node_id="n1")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -996,10 +887,7 @@ def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: py
|
||||
segment = _segment()
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = None
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -1015,10 +903,7 @@ def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch:
|
||||
record = _summary_record(summary_content="old", node_id="n1")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
|
||||
vector_instance = MagicMock()
|
||||
monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance))
|
||||
@ -1044,10 +929,7 @@ def test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: py
|
||||
record = _summary_record(summary_content="old", node_id="n1")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -1073,10 +955,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec
|
||||
record = _summary_record(summary_content="old", node_id="n1")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -1095,10 +974,7 @@ def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.Monke
|
||||
segment = _segment()
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = None
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -1122,10 +998,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
|
||||
record = _summary_record(summary_content="old", node_id="n1")
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = record
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = record
|
||||
session.flush.side_effect = RuntimeError("flush boom")
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -1143,25 +1016,9 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk
|
||||
def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
record = _summary_record(summary_content="sum", node_id="n1")
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = record
|
||||
session.scalars.return_value.all.return_value = [record]
|
||||
|
||||
q1 = MagicMock()
|
||||
q1.where.return_value = q1
|
||||
q1.first.return_value = record
|
||||
|
||||
q2 = MagicMock()
|
||||
q2.filter.return_value = q2
|
||||
q2.all.return_value = [record]
|
||||
|
||||
def query_side_effect(model: object) -> MagicMock:
|
||||
if model is summary_module.DocumentSegmentSummary:
|
||||
# first call used by get_segment_summary, second by get_document_summaries
|
||||
if not hasattr(query_side_effect, "_called"):
|
||||
query_side_effect._called = True # type: ignore[attr-defined]
|
||||
return q1
|
||||
return q2
|
||||
return MagicMock()
|
||||
|
||||
session.query.side_effect = query_side_effect
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -1178,10 +1035,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
record2 = _summary_record()
|
||||
record2.chunk_id = "seg-2"
|
||||
session = MagicMock()
|
||||
q = MagicMock()
|
||||
q.where.return_value = q
|
||||
q.all.return_value = [record1, record2]
|
||||
session.query.return_value = q
|
||||
session.scalars.return_value.all.return_value = [record1, record2]
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -1194,10 +1048,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
|
||||
def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = MagicMock()
|
||||
q = MagicMock()
|
||||
q.where.return_value = q
|
||||
q.all.return_value = []
|
||||
session.query.return_value = q
|
||||
session.scalars.return_value.all.return_value = []
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -1212,10 +1063,7 @@ def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.Monk
|
||||
|
||||
def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = MagicMock()
|
||||
q = MagicMock()
|
||||
q.where.return_value = q
|
||||
q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")]
|
||||
session.query.return_value = q
|
||||
session.execute.return_value.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")]
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
@ -1237,10 +1085,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro
|
||||
segment = _segment()
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.filter_by.return_value = query
|
||||
query.first.return_value = None
|
||||
session.query.return_value = query
|
||||
session.scalar.return_value = None
|
||||
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
@ -1267,10 +1112,7 @@ def test_get_segments_summaries_empty_list() -> None:
|
||||
def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
seg_row = SimpleNamespace(id="seg-1", document_id="doc-1")
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.all.return_value = [SimpleNamespace(id="seg-1")]
|
||||
session.query.return_value = query
|
||||
session.scalars.return_value.all.return_value = ["seg-1"] # get_document_summary_index_status returns IDs
|
||||
|
||||
create_session_mock = MagicMock(return_value=_SessionContext(session))
|
||||
monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock))
|
||||
@ -1283,11 +1125,8 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt
|
||||
assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING"
|
||||
|
||||
# Multiple docs
|
||||
query2 = MagicMock()
|
||||
query2.where.return_value = query2
|
||||
query2.all.return_value = [seg_row]
|
||||
session2 = MagicMock()
|
||||
session2.query.return_value = query2
|
||||
session2.execute.return_value.all.return_value = [seg_row] # get_documents_summary_index_status uses execute
|
||||
monkeypatch.setattr(
|
||||
summary_module,
|
||||
"session_factory",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user