refactor: accept db.session explicitly in SavedMessageService (#37682)

This commit is contained in:
Rohit Gahlawat 2026-06-20 18:05:06 +05:30 committed by GitHub
parent b3e724dce1
commit dcff1870d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 41 additions and 32 deletions

View File

@ -11,6 +11,7 @@ from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from extensions.ext_database import db
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from models import Account
@ -37,6 +38,7 @@ class SavedMessageListApi(InstalledAppResource):
args = SavedMessageListQuery.model_validate(request.args.to_dict())
pagination = SavedMessageService.pagination_by_last_id(
db.session(),
app_model,
current_user,
str(args.last_id) if args.last_id else None,
@ -63,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {})
try:
SavedMessageService.save(app_model, current_user, str(payload.message_id))
SavedMessageService.save(db.session(), app_model, current_user, str(payload.message_id))
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -86,6 +88,6 @@ class SavedMessageApi(InstalledAppResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
SavedMessageService.delete(app_model, current_user, message_id_str)
SavedMessageService.delete(db.session(), app_model, current_user, message_id_str)
return "", 204

View File

@ -9,6 +9,7 @@ from controllers.common.schema import query_params_from_model, register_response
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from models.model import App, EndUser
@ -42,7 +43,9 @@ class SavedMessageListApi(WebApiResource):
raw_args = request.args.to_dict()
query = SavedMessageListQuery.model_validate(raw_args)
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
pagination = SavedMessageService.pagination_by_last_id(
db.session(), app_model, end_user, query.last_id, query.limit
)
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
@ -77,7 +80,7 @@ class SavedMessageListApi(WebApiResource):
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
try:
SavedMessageService.save(app_model, end_user, payload.message_id)
SavedMessageService.save(db.session(), app_model, end_user, payload.message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
@ -105,6 +108,6 @@ class SavedMessageApi(WebApiResource):
if app_model.mode != "completion":
raise NotCompletionAppError()
SavedMessageService.delete(app_model, end_user, message_id_str)
SavedMessageService.delete(db.session(), app_model, end_user, message_id_str)
return "", 204

View File

@ -1,6 +1,6 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import CreatorUserRole
@ -12,11 +12,11 @@ from services.message_service import MessageService
class SavedMessageService:
@classmethod
def pagination_by_last_id(
cls, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int
cls, session: Session, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
saved_messages = db.session.scalars(
saved_messages = session.scalars(
select(SavedMessage)
.where(
SavedMessage.app_id == app_model.id,
@ -32,10 +32,10 @@ class SavedMessageService:
)
@classmethod
def save(cls, app_model: App, user: Account | EndUser | None, message_id: str):
def save(cls, session: Session, app_model: App, user: Account | EndUser | None, message_id: str):
if not user:
return
saved_message = db.session.scalar(
saved_message = session.scalar(
select(SavedMessage)
.where(
SavedMessage.app_id == app_model.id,
@ -58,14 +58,14 @@ class SavedMessageService:
created_by=user.id,
)
db.session.add(saved_message)
db.session.commit()
session.add(saved_message)
session.commit()
@classmethod
def delete(cls, app_model: App, user: Account | EndUser | None, message_id: str):
def delete(cls, session: Session, app_model: App, user: Account | EndUser | None, message_id: str):
if not user:
return
saved_message = db.session.scalar(
saved_message = session.scalar(
select(SavedMessage)
.where(
SavedMessage.app_id == app_model.id,
@ -79,5 +79,5 @@ class SavedMessageService:
if not saved_message:
return
db.session.delete(saved_message)
db.session.commit()
session.delete(saved_message)
session.commit()

View File

@ -220,7 +220,9 @@ class TestSavedMessageService:
mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination
# Act: Execute the method under test
result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10)
result = SavedMessageService.pagination_by_last_id(
db_session_with_containers, app_model=app, user=account, last_id=None, limit=10
)
# Assert: Verify the expected outcomes
assert result is not None
@ -294,7 +296,7 @@ class TestSavedMessageService:
# Act: Execute the method under test
result = SavedMessageService.pagination_by_last_id(
app_model=app, user=end_user, last_id="test_last_id", limit=5
db_session_with_containers, app_model=app, user=end_user, last_id="test_last_id", limit=5
)
# Assert: Verify the expected outcomes
@ -344,7 +346,7 @@ class TestSavedMessageService:
mock_external_service_dependencies["message_service"].get_message.return_value = message
# Act: Execute the method under test
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id)
# Assert: Verify the expected outcomes
# Check if saved message was created in database
@ -393,7 +395,9 @@ class TestSavedMessageService:
# Act & Assert: Verify proper error handling
with pytest.raises(ValueError) as exc_info:
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
SavedMessageService.pagination_by_last_id(
db_session_with_containers, app_model=app, user=None, last_id=None, limit=10
)
assert "User is required" in str(exc_info.value)
@ -412,7 +416,7 @@ class TestSavedMessageService:
message = self._create_test_message(db_session_with_containers, app, account)
# Act: Execute the method under test with None user
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
result = SavedMessageService.save(db_session_with_containers, app_model=app, user=None, message_id=message.id)
# Assert: Verify the expected outcomes
assert result is None
@ -471,7 +475,7 @@ class TestSavedMessageService:
)
# Act: Execute the method under test
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
SavedMessageService.delete(db_session_with_containers, app_model=app, user=account, message_id=message.id)
# Assert: Verify the expected outcomes
# Check if saved message was deleted from database
@ -501,7 +505,7 @@ class TestSavedMessageService:
mock_external_service_dependencies["message_service"].get_message.return_value = message
SavedMessageService.save(app_model=app, user=end_user, message_id=message.id)
SavedMessageService.save(db_session_with_containers, app_model=app, user=end_user, message_id=message.id)
saved = (
db_session_with_containers.query(SavedMessage)
@ -522,9 +526,9 @@ class TestSavedMessageService:
mock_external_service_dependencies["message_service"].get_message.return_value = message
# Save once
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id)
# Save again
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id)
count = (
db_session_with_containers.query(SavedMessage)
@ -547,7 +551,7 @@ class TestSavedMessageService:
db_session_with_containers.add(saved)
db_session_with_containers.commit()
SavedMessageService.delete(app_model=app, user=None, message_id=message.id)
SavedMessageService.delete(db_session_with_containers, app_model=app, user=None, message_id=message.id)
# Should still exist
assert (
@ -566,7 +570,7 @@ class TestSavedMessageService:
# Should not raise — use a valid UUID that doesn't exist in DB
from uuid import uuid4
SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4()))
SavedMessageService.delete(db_session_with_containers, app_model=app, user=account, message_id=str(uuid4()))
def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""Test deleting a saved message for an EndUser."""
@ -580,7 +584,7 @@ class TestSavedMessageService:
db_session_with_containers.add(saved)
db_session_with_containers.commit()
SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id)
SavedMessageService.delete(db_session_with_containers, app_model=app, user=end_user, message_id=message.id)
assert (
db_session_with_containers.query(SavedMessage)
@ -610,7 +614,7 @@ class TestSavedMessageService:
db_session_with_containers.commit()
# Delete only account1's saved message
SavedMessageService.delete(app_model=app, user=account1, message_id=message.id)
SavedMessageService.delete(db_session_with_containers, app_model=app, user=account1, message_id=message.id)
# Account's saved message should be gone
assert (

View File

@ -63,7 +63,7 @@ class TestSavedMessageListApi:
result = method(api, current_user, installed_app)
pagination_mock.assert_called_once()
assert pagination_mock.call_args.args[1] is current_user
assert pagination_mock.call_args.args[2] is current_user
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
@ -96,7 +96,7 @@ class TestSavedMessageListApi:
result = method(api, current_user, installed_app)
save_mock.assert_called_once()
assert save_mock.call_args.args[1] is current_user
assert save_mock.call_args.args[2] is current_user
assert result == {"result": "success"}
def test_post_message_not_exists(self, app: Flask, payload_patch):
@ -136,7 +136,7 @@ class TestSavedMessageApi:
result, status = method(api, current_user, installed_app, str(uuid4()))
delete_mock.assert_called_once()
assert delete_mock.call_args.args[1] is current_user
assert delete_mock.call_args.args[2] is current_user
assert status == 204
assert result == ""