mirror of
https://github.com/langgenius/dify.git
synced 2026-06-23 04:11:09 +08:00
refactor: accept db.session explicitly in FeedbackService (#37694)
This commit is contained in:
parent
adfd820220
commit
75d50455d6
@ -338,6 +338,7 @@ class MessageFeedbackExportApi(Resource):
|
||||
|
||||
try:
|
||||
export_data = FeedbackService.export_feedbacks(
|
||||
db.session(),
|
||||
app_id=app_model.id,
|
||||
from_source=args.from_source,
|
||||
rating=args.rating,
|
||||
|
||||
@ -5,8 +5,8 @@ from datetime import datetime
|
||||
|
||||
from flask import Response
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import Account, App, Conversation, Message, MessageFeedback
|
||||
|
||||
@ -14,6 +14,7 @@ from models.model import Account, App, Conversation, Message, MessageFeedback
|
||||
class FeedbackService:
|
||||
@staticmethod
|
||||
def export_feedbacks(
|
||||
session: Session,
|
||||
app_id: str,
|
||||
from_source: str | None = None,
|
||||
rating: str | None = None,
|
||||
@ -81,7 +82,7 @@ class FeedbackService:
|
||||
stmt = stmt.order_by(MessageFeedback.created_at.desc())
|
||||
|
||||
# Execute query
|
||||
results = db.session.execute(stmt).all()
|
||||
results = session.execute(stmt).all()
|
||||
|
||||
# Prepare data for export
|
||||
export_data = []
|
||||
|
||||
@ -289,6 +289,7 @@ class TestFeedbackExportApi:
|
||||
|
||||
# Verify service was called with correct parameters
|
||||
mock_export_feedbacks.assert_called_once_with(
|
||||
mock.ANY,
|
||||
app_id=mock_app_model.id,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
|
||||
@ -7,7 +7,6 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, Conversation, Message
|
||||
from services.feedback_service import FeedbackService
|
||||
@ -23,11 +22,9 @@ class TestFeedbackService:
|
||||
"""Test FeedbackService methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Mock database session."""
|
||||
mock_session = mock.Mock()
|
||||
monkeypatch.setattr(db, "session", mock_session)
|
||||
return mock_session
|
||||
def mock_db_session(self):
|
||||
"""Mock database session passed explicitly to the service."""
|
||||
return mock.Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
@ -100,7 +97,7 @@ class TestFeedbackService:
|
||||
)
|
||||
|
||||
# Test CSV export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv")
|
||||
|
||||
# Verify response structure
|
||||
assert hasattr(result, "headers")
|
||||
@ -131,7 +128,7 @@ class TestFeedbackService:
|
||||
)
|
||||
|
||||
# Test JSON export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
# Verify response structure
|
||||
assert hasattr(result, "headers")
|
||||
@ -161,6 +158,7 @@ class TestFeedbackService:
|
||||
|
||||
# Test with filters
|
||||
result = FeedbackService.export_feedbacks(
|
||||
mock_db_session,
|
||||
app_id=sample_data["app"].id,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
@ -177,7 +175,7 @@ class TestFeedbackService:
|
||||
"""Test exporting feedback when no data exists."""
|
||||
mock_db_session.execute.return_value = _execute_result([])
|
||||
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv")
|
||||
|
||||
# Should return an empty CSV with headers only
|
||||
assert hasattr(result, "headers")
|
||||
@ -195,17 +193,22 @@ class TestFeedbackService:
|
||||
|
||||
# Test with invalid start_date
|
||||
with pytest.raises(ValueError, match="Invalid start_date format"):
|
||||
FeedbackService.export_feedbacks(app_id=sample_data["app"].id, start_date="invalid-date-format")
|
||||
FeedbackService.export_feedbacks(
|
||||
mock_db_session, app_id=sample_data["app"].id, start_date="invalid-date-format"
|
||||
)
|
||||
|
||||
# Test with invalid end_date
|
||||
with pytest.raises(ValueError, match="Invalid end_date format"):
|
||||
FeedbackService.export_feedbacks(app_id=sample_data["app"].id, end_date="invalid-date-format")
|
||||
FeedbackService.export_feedbacks(
|
||||
mock_db_session, app_id=sample_data["app"].id, end_date="invalid-date-format"
|
||||
)
|
||||
|
||||
def test_export_feedbacks_invalid_format(self, mock_db_session, sample_data):
|
||||
"""Test exporting feedback with unsupported format."""
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported format"):
|
||||
FeedbackService.export_feedbacks(
|
||||
mock_db_session,
|
||||
app_id=sample_data["app"].id,
|
||||
format_type="xml", # Unsupported format
|
||||
)
|
||||
@ -236,7 +239,7 @@ class TestFeedbackService:
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
# Check JSON content
|
||||
json_content = json.loads(result.get_data(as_text=True))
|
||||
@ -287,7 +290,7 @@ class TestFeedbackService:
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
|
||||
result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv")
|
||||
|
||||
# Check that unicode content is preserved
|
||||
csv_content = result.get_data(as_text=True)
|
||||
@ -317,7 +320,7 @@ class TestFeedbackService:
|
||||
)
|
||||
|
||||
# Test export
|
||||
result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
|
||||
result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json")
|
||||
|
||||
# Check JSON content for emoji ratings
|
||||
json_content = json.loads(result.get_data(as_text=True))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user