test: migrate test_messages_clean_service to SQLAlchemy 2.0 select() API (#34984)

This commit is contained in:
dataCenter430 2026-04-11 22:21:07 -07:00 committed by GitHub
parent 7ba70869aa
commit 64920ef648
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@ import math
import uuid
import pytest
from sqlalchemy import delete
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from models import Tenant
@ -210,7 +210,7 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 0
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids)))
assert remaining == len(all_ids)
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
@ -231,7 +231,7 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == len(all_ids)
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids)))
assert remaining == 0
def test_start_from_filters_correctly(self, seed_messages):
@ -254,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
with session_factory.create_session() as session:
all_ids = list(msg_ids.values())
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
remaining_ids = set(session.scalars(select(Message.id).where(Message.id.in_(all_ids))).all())
assert msg_ids["old"] not in remaining_ids
assert msg_ids["very_old"] in remaining_ids
@ -282,7 +282,7 @@ class TestMessagesCleanServiceIntegration:
assert stats["batches"] >= expected_batches
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(msg_ids)))
assert remaining == 0
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
@ -319,9 +319,17 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 1
with session_factory.create_session() as session:
assert session.query(Message).where(Message.id == msg_id).count() == 0
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
assert session.scalar(select(func.count()).select_from(Message).where(Message.id == msg_id)) == 0
assert (
session.scalar(select(func.count()).select_from(MessageFeedback).where(MessageFeedback.id == fb_id))
== 0
)
assert (
session.scalar(
select(func.count()).select_from(MessageAnnotation).where(MessageAnnotation.id == ann_id)
)
== 0
)
def test_factory_from_time_range_validation(self):
with pytest.raises(ValueError, match="start_from"):