From dcff1870d5a78246a939e9bf26e1c811cbaa7be5 Mon Sep 17 00:00:00 2001 From: Rohit Gahlawat <283466839+Rohit-Gahlawat@users.noreply.github.com> Date: Sat, 20 Jun 2026 18:05:06 +0530 Subject: [PATCH] refactor: accept db.session explicitly in SavedMessageService (#37682) --- .../console/explore/saved_message.py | 6 ++-- api/controllers/web/saved_message.py | 9 ++++-- api/services/saved_message_service.py | 22 +++++++------- .../services/test_saved_message_service.py | 30 +++++++++++-------- .../console/explore/test_saved_message.py | 6 ++-- 5 files changed, 41 insertions(+), 32 deletions(-) diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 3e8f1ce9083..ce43ff18c93 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -11,6 +11,7 @@ from controllers.console.app.error import AppUnavailableError from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import with_current_user +from extensions.ext_database import db from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from models import Account @@ -37,6 +38,7 @@ class SavedMessageListApi(InstalledAppResource): args = SavedMessageListQuery.model_validate(request.args.to_dict()) pagination = SavedMessageService.pagination_by_last_id( + db.session(), app_model, current_user, str(args.last_id) if args.last_id else None, @@ -63,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource): payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {}) try: - SavedMessageService.save(app_model, current_user, str(payload.message_id)) + SavedMessageService.save(db.session(), app_model, current_user, str(payload.message_id)) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -86,6 +88,6 @@ class SavedMessageApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() - SavedMessageService.delete(app_model, current_user, message_id_str) + SavedMessageService.delete(db.session(), app_model, current_user, message_id_str) return "", 204 diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index e3baa028e50..6e59a85e2b0 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -9,6 +9,7 @@ from controllers.common.schema import query_params_from_model, register_response from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource +from extensions.ext_database import db from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem from models.model import App, EndUser @@ -42,7 +43,9 @@ class SavedMessageListApi(WebApiResource): raw_args = request.args.to_dict() query = SavedMessageListQuery.model_validate(raw_args) - pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit) + pagination = SavedMessageService.pagination_by_last_id( + db.session(), app_model, end_user, query.last_id, query.limit + ) adapter = TypeAdapter(SavedMessageItem) items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] return SavedMessageInfiniteScrollPagination( @@ -77,7 +80,7 @@ class SavedMessageListApi(WebApiResource): payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {}) try: - SavedMessageService.save(app_model, end_user, payload.message_id) + SavedMessageService.save(db.session(), app_model, end_user, payload.message_id) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -105,6 +108,6 @@ class SavedMessageApi(WebApiResource): if app_model.mode != "completion": raise NotCompletionAppError() - SavedMessageService.delete(app_model, end_user, message_id_str) + SavedMessageService.delete(db.session(), app_model, end_user, message_id_str) return "", 204 diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 90f01377123..9a65429748e 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -1,6 +1,6 @@ from sqlalchemy import select +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import CreatorUserRole @@ -12,11 +12,11 @@ from services.message_service import MessageService class SavedMessageService: @classmethod def pagination_by_last_id( - cls, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int + cls, session: Session, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int ) -> InfiniteScrollPagination: if not user: raise ValueError("User is required") - saved_messages = db.session.scalars( + saved_messages = session.scalars( select(SavedMessage) .where( SavedMessage.app_id == app_model.id, @@ -32,10 +32,10 @@ class SavedMessageService: ) @classmethod - def save(cls, app_model: App, user: Account | EndUser | None, message_id: str): + def save(cls, session: Session, app_model: App, user: Account | EndUser | None, message_id: str): if not user: return - saved_message = db.session.scalar( + saved_message = session.scalar( select(SavedMessage) .where( SavedMessage.app_id == app_model.id, @@ -58,14 +58,14 @@ class SavedMessageService: created_by=user.id, ) - db.session.add(saved_message) - db.session.commit() + session.add(saved_message) + session.commit() @classmethod - def delete(cls, app_model: App, user: Account | EndUser | None, message_id: str): + def delete(cls, session: Session, app_model: App, user: Account | EndUser | None, message_id: str): if not user: return - saved_message = db.session.scalar( + saved_message = session.scalar( select(SavedMessage) .where( SavedMessage.app_id == app_model.id, @@ -79,5 +79,5 @@ class SavedMessageService: if not saved_message: return - db.session.delete(saved_message) - db.session.commit() + session.delete(saved_message) + session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index ac434021fc8..ad85ac67bc5 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -220,7 +220,9 @@ class TestSavedMessageService: mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination # Act: Execute the method under test - result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10) + result = SavedMessageService.pagination_by_last_id( + db_session_with_containers, app_model=app, user=account, last_id=None, limit=10 + ) # Assert: Verify the expected outcomes assert result is not None @@ -294,7 +296,7 @@ class TestSavedMessageService: # Act: Execute the method under test result = SavedMessageService.pagination_by_last_id( - app_model=app, user=end_user, last_id="test_last_id", limit=5 + db_session_with_containers, app_model=app, user=end_user, last_id="test_last_id", limit=5 ) # Assert: Verify the expected outcomes @@ -344,7 +346,7 @@ class TestSavedMessageService: mock_external_service_dependencies["message_service"].get_message.return_value = message # Act: Execute the method under test - SavedMessageService.save(app_model=app, user=account, message_id=message.id) + SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id) # Assert: Verify the expected outcomes # Check if saved message was created in database @@ -393,7 +395,9 @@ class TestSavedMessageService: # Act & Assert: Verify proper error handling with pytest.raises(ValueError) as exc_info: - SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + SavedMessageService.pagination_by_last_id( + db_session_with_containers, app_model=app, user=None, last_id=None, limit=10 + ) assert "User is required" in str(exc_info.value) @@ -412,7 +416,7 @@ class TestSavedMessageService: message = self._create_test_message(db_session_with_containers, app, account) # Act: Execute the method under test with None user - result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) + result = SavedMessageService.save(db_session_with_containers, app_model=app, user=None, message_id=message.id) # Assert: Verify the expected outcomes assert result is None @@ -471,7 +475,7 @@ class TestSavedMessageService: ) # Act: Execute the method under test - SavedMessageService.delete(app_model=app, user=account, message_id=message.id) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=account, message_id=message.id) # Assert: Verify the expected outcomes # Check if saved message was deleted from database @@ -501,7 +505,7 @@ class TestSavedMessageService: mock_external_service_dependencies["message_service"].get_message.return_value = message - SavedMessageService.save(app_model=app, user=end_user, message_id=message.id) + SavedMessageService.save(db_session_with_containers, app_model=app, user=end_user, message_id=message.id) saved = ( db_session_with_containers.query(SavedMessage) @@ -522,9 +526,9 @@ class TestSavedMessageService: mock_external_service_dependencies["message_service"].get_message.return_value = message # Save once - SavedMessageService.save(app_model=app, user=account, message_id=message.id) + SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id) # Save again - SavedMessageService.save(app_model=app, user=account, message_id=message.id) + SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id) count = ( db_session_with_containers.query(SavedMessage) @@ -547,7 +551,7 @@ class TestSavedMessageService: db_session_with_containers.add(saved) db_session_with_containers.commit() - SavedMessageService.delete(app_model=app, user=None, message_id=message.id) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=None, message_id=message.id) # Should still exist assert ( @@ -566,7 +570,7 @@ class TestSavedMessageService: # Should not raise — use a valid UUID that doesn't exist in DB from uuid import uuid4 - SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4())) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=account, message_id=str(uuid4())) def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """Test deleting a saved message for an EndUser.""" @@ -580,7 +584,7 @@ class TestSavedMessageService: db_session_with_containers.add(saved) db_session_with_containers.commit() - SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=end_user, message_id=message.id) assert ( db_session_with_containers.query(SavedMessage) @@ -610,7 +614,7 @@ class TestSavedMessageService: db_session_with_containers.commit() # Delete only account1's saved message - SavedMessageService.delete(app_model=app, user=account1, message_id=message.id) + SavedMessageService.delete(db_session_with_containers, app_model=app, user=account1, message_id=message.id) # Account's saved message should be gone assert ( diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py index f210d0d5d04..ae05b8f6a0e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -63,7 +63,7 @@ class TestSavedMessageListApi: result = method(api, current_user, installed_app) pagination_mock.assert_called_once() - assert pagination_mock.call_args.args[1] is current_user + assert pagination_mock.call_args.args[2] is current_user assert result["limit"] == 20 assert result["has_more"] is False assert len(result["data"]) == 2 @@ -96,7 +96,7 @@ class TestSavedMessageListApi: result = method(api, current_user, installed_app) save_mock.assert_called_once() - assert save_mock.call_args.args[1] is current_user + assert save_mock.call_args.args[2] is current_user assert result == {"result": "success"} def test_post_message_not_exists(self, app: Flask, payload_patch): @@ -136,7 +136,7 @@ class TestSavedMessageApi: result, status = method(api, current_user, installed_app, str(uuid4())) delete_mock.assert_called_once() - assert delete_mock.call_args.args[1] is current_user + assert delete_mock.call_args.args[2] is current_user assert status == 204 assert result == ""