diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 43b41903f60..b66c97c274c 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -30,6 +30,7 @@ from controllers.console.wraps import ( setup_required, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode @@ -142,6 +143,7 @@ class ChatMessageTextApi(Resource): response = AudioService.transcript_tts( app_model=app_model, + session=db.session, text=payload.text, voice=payload.voice, message_id=payload.message_id, diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 756dfe84f6c..c2104ccfc61 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -20,6 +20,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from models.model import InstalledApp from services.audio_service import AudioService @@ -99,7 +100,13 @@ class ChatTextApi(InstalledAppResource): text = payload.text voice = payload.voice - response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) + response = AudioService.transcript_tts( + app_model=app_model, + session=db.session, + text=text, + voice=voice, + message_id=message_id, + ) return response except services.errors.app_model_config.AppModelConfigBrokenError: logger.exception("App model config broken.") diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index ad98dd303fb..6aef9129780 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -419,7 +419,13 @@ class TrialChatTextApi(TrialAppResource): app_id = app_model.id user_id = current_user.id - response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) + response = AudioService.transcript_tts( + app_model=app_model, + session=db.session, + text=text, + voice=voice, + message_id=message_id, + ) RecommendedAppService.add_trial_app_record(db.session, app_id, user_id) return response except services.errors.app_model_config.AppModelConfigBrokenError: diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 2b5a9ba83a1..59ed4b4a4b1 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -23,6 +23,7 @@ from controllers.service_api.app.error import ( from controllers.service_api.schema import binary_response, expect_with_user, multipart_file_params from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService @@ -177,7 +178,12 @@ class TextApi(Resource): text = payload.text voice = payload.voice response = AudioService.transcript_tts( - app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id + app_model=app_model, + session=db.session, + text=text, + voice=voice, + end_user=end_user.external_user_id, + message_id=message_id, ) return response diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index c762c914861..801c1f5a629 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -22,6 +22,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App, EndUser @@ -130,7 +131,12 @@ class TextApi(WebApiResource): text = payload.text voice = payload.voice response = AudioService.transcript_tts( - app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id + app_model=app_model, + session=db.session, + text=text, + voice=voice, + end_user=end_user.external_user_id, + message_id=message_id, ) return response diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a9024eb3bdd..14c5c0111e5 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,11 +5,11 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context +from sqlalchemy.orm import Session, scoped_session from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message @@ -77,6 +77,8 @@ class AudioService: def transcript_tts( cls, app_model: App, + *, + session: Session | scoped_session, text: str | None = None, voice: str | None = None, end_user: str | None = None, @@ -87,7 +89,7 @@ class AudioService: if voice is None: if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if is_draft: - workflow = WorkflowService().get_draft_workflow(app_model=app_model) + workflow = WorkflowService().get_draft_workflow(app_model=app_model, session=session) else: workflow = app_model.workflow if ( @@ -132,7 +134,7 @@ class AudioService: uuid.UUID(message_id) except ValueError: return None - message = db.session.get(Message, message_id) + message = session.get(Message, message_id) if message is None: return None if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 9f8e4b83093..262ccc18f83 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast from sqlalchemy import exists, select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session, scoped_session, sessionmaker from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -142,7 +142,7 @@ class WorkflowService: return db.session.execute(stmt).scalar_one() def get_draft_workflow( - self, app_model: App, workflow_id: str | None = None, session: Session | None = None + self, app_model: App, workflow_id: str | None = None, session: Session | scoped_session | None = None ) -> Workflow | None: """ Get draft workflow @@ -169,7 +169,7 @@ class WorkflowService: return workflow def get_published_workflow_by_id( - self, app_model: App, workflow_id: str, session: Session | None = None + self, app_model: App, workflow_id: str, session: Session | scoped_session | None = None ) -> Workflow | None: """ fetch published workflow by workflow_id diff --git a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py index 2593b53fe84..c9cf60bcfb1 100644 --- a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py +++ b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py @@ -158,6 +158,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: with patch("services.audio_service.ModelManager.for_tenant", return_value=mock_model_manager): result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=message.id, voice="en-US-Neural", ) @@ -174,6 +175,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id="invalid-uuid", ) @@ -185,6 +187,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=str(uuid4()), ) @@ -205,6 +208,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=message.id, ) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 1cfe152c864..52d050ff55a 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -176,6 +176,7 @@ class TestAudioServiceMockedBehavior: result = AudioService.transcript_tts( app_model=mock_app, + session=Mock(), text="Hello world", voice="nova", end_user="user_123", diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 5d148974f87..788a47c5c31 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -398,6 +398,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Hello world", voice="en-US-Neural", end_user="user-123", @@ -432,6 +433,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Test", ) @@ -465,6 +467,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Test", ) @@ -496,17 +499,52 @@ class TestAudioServiceTTS: mock_model_instance = MagicMock() mock_model_instance.invoke_tts.return_value = b"draft audio" mock_model_manager.get_default_model_instance.return_value = mock_model_instance + session = MagicMock() # Act result = AudioService.transcript_tts( app_model=app, + session=session, text="Draft test", is_draft=True, ) # Assert assert result == b"draft audio" - mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app) + mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app, session=session) + + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) + def test_transcript_tts_message_id_uses_provided_session( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): + """Test TTS message lookup uses the injected session.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.CHAT) + message_id = "00000000-0000-0000-0000-000000000001" + message = factory.create_message_mock(message_id=message_id, answer="Message answer") + session = MagicMock() + session.get.return_value = message + + mock_model_manager = mock_model_manager_class.return_value + mock_model_instance = MagicMock() + mock_model_instance.invoke_tts.return_value = b"message audio" + mock_model_manager.get_default_model_instance.return_value = mock_model_instance + + # Act + result = AudioService.transcript_tts( + app_model=app, + session=session, + message_id=message_id, + voice="message-voice", + ) + + # Assert + assert result == b"message audio" + session.get.assert_called_once_with(Message, message_id) + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="Message answer", + voice="message-voice", + ) def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory): """Test that TTS raises error when text is missing.""" @@ -515,7 +553,7 @@ class TestAudioServiceTTS: # Act & Assert with pytest.raises(ValueError, match="Text is required"): - AudioService.transcript_tts(app_model=app, text=None) + AudioService.transcript_tts(app_model=app, session=MagicMock(), text=None) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_raises_error_when_no_voices_available( @@ -539,7 +577,7 @@ class TestAudioServiceTTS: # Act & Assert with pytest.raises(ValueError, match="Sorry, no voice available"): - AudioService.transcript_tts(app_model=app, text="Test") + AudioService.transcript_tts(app_model=app, session=MagicMock(), text="Test") class TestAudioServiceTTSVoices: