mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
test: migrate test_messages_clean_service to SQLAlchemy 2.0 select() API (#34984)
This commit is contained in:
parent
7ba70869aa
commit
64920ef648
@ -3,7 +3,7 @@ import math
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
@ -210,7 +210,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids)))
|
||||
assert remaining == len(all_ids)
|
||||
|
||||
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
|
||||
@ -231,7 +231,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == len(all_ids)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids)))
|
||||
assert remaining == 0
|
||||
|
||||
def test_start_from_filters_correctly(self, seed_messages):
|
||||
@ -254,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
all_ids = list(msg_ids.values())
|
||||
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
|
||||
remaining_ids = set(session.scalars(select(Message.id).where(Message.id.in_(all_ids))).all())
|
||||
|
||||
assert msg_ids["old"] not in remaining_ids
|
||||
assert msg_ids["very_old"] in remaining_ids
|
||||
@ -282,7 +282,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["batches"] >= expected_batches
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
|
||||
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(msg_ids)))
|
||||
assert remaining == 0
|
||||
|
||||
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
|
||||
@ -319,9 +319,17 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(Message).where(Message.id == msg_id).count() == 0
|
||||
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
|
||||
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
|
||||
assert session.scalar(select(func.count()).select_from(Message).where(Message.id == msg_id)) == 0
|
||||
assert (
|
||||
session.scalar(select(func.count()).select_from(MessageFeedback).where(MessageFeedback.id == fb_id))
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
session.scalar(
|
||||
select(func.count()).select_from(MessageAnnotation).where(MessageAnnotation.id == ann_id)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
def test_factory_from_time_range_validation(self):
|
||||
with pytest.raises(ValueError, match="start_from"):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user