mirror of
https://github.com/langgenius/dify.git
synced 2026-06-22 19:21:13 +08:00
refactor: accept db.session explicitly in SavedMessageService (#37682)
This commit is contained in:
parent
b3e724dce1
commit
dcff1870d5
@ -11,6 +11,7 @@ from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from models import Account
|
||||
@ -37,6 +38,7 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
args = SavedMessageListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
pagination = SavedMessageService.pagination_by_last_id(
|
||||
db.session(),
|
||||
app_model,
|
||||
current_user,
|
||||
str(args.last_id) if args.last_id else None,
|
||||
@ -63,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
payload = SavedMessageCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
SavedMessageService.save(app_model, current_user, str(payload.message_id))
|
||||
SavedMessageService.save(db.session(), app_model, current_user, str(payload.message_id))
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@ -86,6 +88,6 @@ class SavedMessageApi(InstalledAppResource):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id_str)
|
||||
SavedMessageService.delete(db.session(), app_model, current_user, message_id_str)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -9,6 +9,7 @@ from controllers.common.schema import query_params_from_model, register_response
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from models.model import App, EndUser
|
||||
@ -42,7 +43,9 @@ class SavedMessageListApi(WebApiResource):
|
||||
raw_args = request.args.to_dict()
|
||||
query = SavedMessageListQuery.model_validate(raw_args)
|
||||
|
||||
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, query.last_id, query.limit)
|
||||
pagination = SavedMessageService.pagination_by_last_id(
|
||||
db.session(), app_model, end_user, query.last_id, query.limit
|
||||
)
|
||||
adapter = TypeAdapter(SavedMessageItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return SavedMessageInfiniteScrollPagination(
|
||||
@ -77,7 +80,7 @@ class SavedMessageListApi(WebApiResource):
|
||||
payload = SavedMessageCreatePayload.model_validate(web_ns.payload or {})
|
||||
|
||||
try:
|
||||
SavedMessageService.save(app_model, end_user, payload.message_id)
|
||||
SavedMessageService.save(db.session(), app_model, end_user, payload.message_id)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@ -105,6 +108,6 @@ class SavedMessageApi(WebApiResource):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
SavedMessageService.delete(app_model, end_user, message_id_str)
|
||||
SavedMessageService.delete(db.session(), app_model, end_user, message_id_str)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
@ -12,11 +12,11 @@ from services.message_service import MessageService
|
||||
class SavedMessageService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(
|
||||
cls, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int
|
||||
cls, session: Session, app_model: App, user: Account | EndUser | None, last_id: str | None, limit: int
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
raise ValueError("User is required")
|
||||
saved_messages = db.session.scalars(
|
||||
saved_messages = session.scalars(
|
||||
select(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
@ -32,10 +32,10 @@ class SavedMessageService:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save(cls, app_model: App, user: Account | EndUser | None, message_id: str):
|
||||
def save(cls, session: Session, app_model: App, user: Account | EndUser | None, message_id: str):
|
||||
if not user:
|
||||
return
|
||||
saved_message = db.session.scalar(
|
||||
saved_message = session.scalar(
|
||||
select(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
@ -58,14 +58,14 @@ class SavedMessageService:
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
db.session.add(saved_message)
|
||||
db.session.commit()
|
||||
session.add(saved_message)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def delete(cls, app_model: App, user: Account | EndUser | None, message_id: str):
|
||||
def delete(cls, session: Session, app_model: App, user: Account | EndUser | None, message_id: str):
|
||||
if not user:
|
||||
return
|
||||
saved_message = db.session.scalar(
|
||||
saved_message = session.scalar(
|
||||
select(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
@ -79,5 +79,5 @@ class SavedMessageService:
|
||||
if not saved_message:
|
||||
return
|
||||
|
||||
db.session.delete(saved_message)
|
||||
db.session.commit()
|
||||
session.delete(saved_message)
|
||||
session.commit()
|
||||
|
||||
@ -220,7 +220,9 @@ class TestSavedMessageService:
|
||||
mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10)
|
||||
result = SavedMessageService.pagination_by_last_id(
|
||||
db_session_with_containers, app_model=app, user=account, last_id=None, limit=10
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is not None
|
||||
@ -294,7 +296,7 @@ class TestSavedMessageService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = SavedMessageService.pagination_by_last_id(
|
||||
app_model=app, user=end_user, last_id="test_last_id", limit=5
|
||||
db_session_with_containers, app_model=app, user=end_user, last_id="test_last_id", limit=5
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -344,7 +346,7 @@ class TestSavedMessageService:
|
||||
mock_external_service_dependencies["message_service"].get_message.return_value = message
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
|
||||
SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was created in database
|
||||
@ -393,7 +395,9 @@ class TestSavedMessageService:
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10)
|
||||
SavedMessageService.pagination_by_last_id(
|
||||
db_session_with_containers, app_model=app, user=None, last_id=None, limit=10
|
||||
)
|
||||
|
||||
assert "User is required" in str(exc_info.value)
|
||||
|
||||
@ -412,7 +416,7 @@ class TestSavedMessageService:
|
||||
message = self._create_test_message(db_session_with_containers, app, account)
|
||||
|
||||
# Act: Execute the method under test with None user
|
||||
result = SavedMessageService.save(app_model=app, user=None, message_id=message.id)
|
||||
result = SavedMessageService.save(db_session_with_containers, app_model=app, user=None, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result is None
|
||||
@ -471,7 +475,7 @@ class TestSavedMessageService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
SavedMessageService.delete(app_model=app, user=account, message_id=message.id)
|
||||
SavedMessageService.delete(db_session_with_containers, app_model=app, user=account, message_id=message.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was deleted from database
|
||||
@ -501,7 +505,7 @@ class TestSavedMessageService:
|
||||
|
||||
mock_external_service_dependencies["message_service"].get_message.return_value = message
|
||||
|
||||
SavedMessageService.save(app_model=app, user=end_user, message_id=message.id)
|
||||
SavedMessageService.save(db_session_with_containers, app_model=app, user=end_user, message_id=message.id)
|
||||
|
||||
saved = (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
@ -522,9 +526,9 @@ class TestSavedMessageService:
|
||||
mock_external_service_dependencies["message_service"].get_message.return_value = message
|
||||
|
||||
# Save once
|
||||
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
|
||||
SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id)
|
||||
# Save again
|
||||
SavedMessageService.save(app_model=app, user=account, message_id=message.id)
|
||||
SavedMessageService.save(db_session_with_containers, app_model=app, user=account, message_id=message.id)
|
||||
|
||||
count = (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
@ -547,7 +551,7 @@ class TestSavedMessageService:
|
||||
db_session_with_containers.add(saved)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
SavedMessageService.delete(app_model=app, user=None, message_id=message.id)
|
||||
SavedMessageService.delete(db_session_with_containers, app_model=app, user=None, message_id=message.id)
|
||||
|
||||
# Should still exist
|
||||
assert (
|
||||
@ -566,7 +570,7 @@ class TestSavedMessageService:
|
||||
# Should not raise — use a valid UUID that doesn't exist in DB
|
||||
from uuid import uuid4
|
||||
|
||||
SavedMessageService.delete(app_model=app, user=account, message_id=str(uuid4()))
|
||||
SavedMessageService.delete(db_session_with_containers, app_model=app, user=account, message_id=str(uuid4()))
|
||||
|
||||
def test_delete_for_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""Test deleting a saved message for an EndUser."""
|
||||
@ -580,7 +584,7 @@ class TestSavedMessageService:
|
||||
db_session_with_containers.add(saved)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
SavedMessageService.delete(app_model=app, user=end_user, message_id=message.id)
|
||||
SavedMessageService.delete(db_session_with_containers, app_model=app, user=end_user, message_id=message.id)
|
||||
|
||||
assert (
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
@ -610,7 +614,7 @@ class TestSavedMessageService:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Delete only account1's saved message
|
||||
SavedMessageService.delete(app_model=app, user=account1, message_id=message.id)
|
||||
SavedMessageService.delete(db_session_with_containers, app_model=app, user=account1, message_id=message.id)
|
||||
|
||||
# Account's saved message should be gone
|
||||
assert (
|
||||
|
||||
@ -63,7 +63,7 @@ class TestSavedMessageListApi:
|
||||
result = method(api, current_user, installed_app)
|
||||
|
||||
pagination_mock.assert_called_once()
|
||||
assert pagination_mock.call_args.args[1] is current_user
|
||||
assert pagination_mock.call_args.args[2] is current_user
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
@ -96,7 +96,7 @@ class TestSavedMessageListApi:
|
||||
result = method(api, current_user, installed_app)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert save_mock.call_args.args[1] is current_user
|
||||
assert save_mock.call_args.args[2] is current_user
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_post_message_not_exists(self, app: Flask, payload_patch):
|
||||
@ -136,7 +136,7 @@ class TestSavedMessageApi:
|
||||
result, status = method(api, current_user, installed_app, str(uuid4()))
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
assert delete_mock.call_args.args[1] is current_user
|
||||
assert delete_mock.call_args.args[2] is current_user
|
||||
assert status == 204
|
||||
assert result == ""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user