refactor: inject session into audio TTS (#37849)

This commit is contained in:
Myshkin451 2026-06-24 12:49:48 +08:00 committed by GitHub
parent 1c1b20aa46
commit 31c08faded
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 85 additions and 13 deletions

View File

@ -30,6 +30,7 @@ from controllers.console.wraps import (
setup_required,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from models import App, AppMode
@ -142,6 +143,7 @@ class ChatMessageTextApi(Resource):
response = AudioService.transcript_tts(
app_model=app_model,
session=db.session,
text=payload.text,
voice=payload.voice,
message_id=payload.message_id,

View File

@ -20,6 +20,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from models.model import InstalledApp
from services.audio_service import AudioService
@ -99,7 +100,13 @@ class ChatTextApi(InstalledAppResource):
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
response = AudioService.transcript_tts(
app_model=app_model,
session=db.session,
text=text,
voice=voice,
message_id=message_id,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")

View File

@ -419,7 +419,13 @@ class TrialChatTextApi(TrialAppResource):
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
response = AudioService.transcript_tts(
app_model=app_model,
session=db.session,
text=text,
voice=voice,
message_id=message_id,
)
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:

View File

@ -23,6 +23,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.schema import binary_response, expect_with_user, multipart_file_params
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from models.model import App, EndUser
from services.audio_service import AudioService
@ -177,7 +178,12 @@ class TextApi(Resource):
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
app_model=app_model,
session=db.session,
text=text,
voice=voice,
end_user=end_user.external_user_id,
message_id=message_id,
)
return response

View File

@ -22,6 +22,7 @@ from controllers.web.error import (
)
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from extensions.ext_database import db
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App, EndUser
@ -130,7 +131,12 @@ class TextApi(WebApiResource):
text = payload.text
voice = payload.voice
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
app_model=app_model,
session=db.session,
text=text,
voice=voice,
end_user=end_user.external_user_id,
message_id=message_id,
)
return response

View File

@ -5,11 +5,11 @@ from collections.abc import Generator
from typing import cast
from flask import Response, stream_with_context
from sqlalchemy.orm import Session, scoped_session
from werkzeug.datastructures import FileStorage
from constants import AUDIO_EXTENSIONS
from core.model_manager import ModelManager
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from models.enums import MessageStatus
from models.model import App, AppMode, Message
@ -77,6 +77,8 @@ class AudioService:
def transcript_tts(
cls,
app_model: App,
*,
session: Session | scoped_session,
text: str | None = None,
voice: str | None = None,
end_user: str | None = None,
@ -87,7 +89,7 @@ class AudioService:
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
workflow = WorkflowService().get_draft_workflow(app_model=app_model, session=session)
else:
workflow = app_model.workflow
if (
@ -132,7 +134,7 @@ class AudioService:
uuid.UUID(message_id)
except ValueError:
return None
message = db.session.get(Message, message_id)
message = session.get(Message, message_id)
if message is None:
return None
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:

View File

@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, cast
from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
@ -142,7 +142,7 @@ class WorkflowService:
return db.session.execute(stmt).scalar_one()
def get_draft_workflow(
self, app_model: App, workflow_id: str | None = None, session: Session | None = None
self, app_model: App, workflow_id: str | None = None, session: Session | scoped_session | None = None
) -> Workflow | None:
"""
Get draft workflow
@ -169,7 +169,7 @@ class WorkflowService:
return workflow
def get_published_workflow_by_id(
self, app_model: App, workflow_id: str, session: Session | None = None
self, app_model: App, workflow_id: str, session: Session | scoped_session | None = None
) -> Workflow | None:
"""
fetch published workflow by workflow_id

View File

@ -158,6 +158,7 @@ class TestAudioServiceTranscriptTTSMessageLookup:
with patch("services.audio_service.ModelManager.for_tenant", return_value=mock_model_manager):
result = AudioService.transcript_tts(
app_model=app,
session=db_session_with_containers,
message_id=message.id,
voice="en-US-Neural",
)
@ -174,6 +175,7 @@ class TestAudioServiceTranscriptTTSMessageLookup:
result = AudioService.transcript_tts(
app_model=app,
session=db_session_with_containers,
message_id="invalid-uuid",
)
@ -185,6 +187,7 @@ class TestAudioServiceTranscriptTTSMessageLookup:
result = AudioService.transcript_tts(
app_model=app,
session=db_session_with_containers,
message_id=str(uuid4()),
)
@ -205,6 +208,7 @@ class TestAudioServiceTranscriptTTSMessageLookup:
result = AudioService.transcript_tts(
app_model=app,
session=db_session_with_containers,
message_id=message.id,
)

View File

@ -176,6 +176,7 @@ class TestAudioServiceMockedBehavior:
result = AudioService.transcript_tts(
app_model=mock_app,
session=Mock(),
text="Hello world",
voice="nova",
end_user="user_123",

View File

@ -398,6 +398,7 @@ class TestAudioServiceTTS:
# Act
result = AudioService.transcript_tts(
app_model=app,
session=MagicMock(),
text="Hello world",
voice="en-US-Neural",
end_user="user-123",
@ -432,6 +433,7 @@ class TestAudioServiceTTS:
# Act
result = AudioService.transcript_tts(
app_model=app,
session=MagicMock(),
text="Test",
)
@ -465,6 +467,7 @@ class TestAudioServiceTTS:
# Act
result = AudioService.transcript_tts(
app_model=app,
session=MagicMock(),
text="Test",
)
@ -496,17 +499,52 @@ class TestAudioServiceTTS:
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"draft audio"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
session = MagicMock()
# Act
result = AudioService.transcript_tts(
app_model=app,
session=session,
text="Draft test",
is_draft=True,
)
# Assert
assert result == b"draft audio"
mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app)
mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app, session=session)
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_message_id_uses_provided_session(
    self, mock_model_manager_class, factory: AudioServiceTestDataFactory
):
    """Check that a message-based TTS request reads the message via the injected session.

    The session passed to ``AudioService.transcript_tts`` is a mock whose ``get``
    returns a prepared message. The test confirms that this exact session was
    queried once for the message and that the message's answer text was handed
    to the TTS model together with the requested voice.
    """
    # Arrange: a chat-mode app plus a message whose answer should be synthesized.
    target_id = "00000000-0000-0000-0000-000000000001"
    chat_app = factory.create_app_mock(mode=AppMode.CHAT)
    stored_message = factory.create_message_mock(message_id=target_id, answer="Message answer")

    # The injected session is the only place the message can come from.
    injected_session = MagicMock()
    injected_session.get.return_value = stored_message

    # Wire the patched model manager so it yields a TTS model with canned audio.
    tts_model = MagicMock()
    tts_model.invoke_tts.return_value = b"message audio"
    mock_model_manager_class.return_value.get_default_model_instance.return_value = tts_model

    # Act
    audio = AudioService.transcript_tts(
        app_model=chat_app,
        session=injected_session,
        message_id=target_id,
        voice="message-voice",
    )

    # Assert: audio is returned, the lookup went through the injected session,
    # and the model received the message's answer text.
    assert audio == b"message audio"
    injected_session.get.assert_called_once_with(Message, target_id)
    tts_model.invoke_tts.assert_called_once_with(
        content_text="Message answer",
        voice="message-voice",
    )
def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory):
"""Test that TTS raises error when text is missing."""
@ -515,7 +553,7 @@ class TestAudioServiceTTS:
# Act & Assert
with pytest.raises(ValueError, match="Text is required"):
AudioService.transcript_tts(app_model=app, text=None)
AudioService.transcript_tts(app_model=app, session=MagicMock(), text=None)
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_raises_error_when_no_voices_available(
@ -539,7 +577,7 @@ class TestAudioServiceTTS:
# Act & Assert
with pytest.raises(ValueError, match="Sorry, no voice available"):
AudioService.transcript_tts(app_model=app, text="Test")
AudioService.transcript_tts(app_model=app, session=MagicMock(), text="Test")
class TestAudioServiceTTSVoices: