mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 21:11:16 +08:00
refactor: inject session into audio TTS (#37849)
This commit is contained in:
parent
1c1b20aa46
commit
31c08faded
@ -30,6 +30,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
@ -142,6 +143,7 @@ class ChatMessageTextApi(Resource):
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=payload.text,
|
||||
voice=payload.voice,
|
||||
message_id=payload.message_id,
|
||||
|
||||
@ -20,6 +20,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import InstalledApp
|
||||
from services.audio_service import AudioService
|
||||
@ -99,7 +100,13 @@ class ChatTextApi(InstalledAppResource):
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=text,
|
||||
voice=voice,
|
||||
message_id=message_id,
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
|
||||
@ -419,7 +419,13 @@ class TrialChatTextApi(TrialAppResource):
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=text,
|
||||
voice=voice,
|
||||
message_id=message_id,
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
|
||||
@ -23,6 +23,7 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.schema import binary_response, expect_with_user, multipart_file_params
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, EndUser
|
||||
from services.audio_service import AudioService
|
||||
@ -177,7 +178,12 @@ class TextApi(Resource):
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=text,
|
||||
voice=voice,
|
||||
end_user=end_user.external_user_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -22,6 +22,7 @@ from controllers.web.error import (
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, EndUser
|
||||
@ -130,7 +131,12 @@ class TextApi(WebApiResource):
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=text,
|
||||
voice=voice,
|
||||
end_user=end_user.external_user_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -5,11 +5,11 @@ from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from constants import AUDIO_EXTENSIONS
|
||||
from core.model_manager import ModelManager
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, Message
|
||||
@ -77,6 +77,8 @@ class AudioService:
|
||||
def transcript_tts(
|
||||
cls,
|
||||
app_model: App,
|
||||
*,
|
||||
session: Session | scoped_session,
|
||||
text: str | None = None,
|
||||
voice: str | None = None,
|
||||
end_user: str | None = None,
|
||||
@ -87,7 +89,7 @@ class AudioService:
|
||||
if voice is None:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if is_draft:
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model, session=session)
|
||||
else:
|
||||
workflow = app_model.workflow
|
||||
if (
|
||||
@ -132,7 +134,7 @@ class AudioService:
|
||||
uuid.UUID(message_id)
|
||||
except ValueError:
|
||||
return None
|
||||
message = db.session.get(Message, message_id)
|
||||
message = session.get(Message, message_id)
|
||||
if message is None:
|
||||
return None
|
||||
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
|
||||
|
||||
@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
@ -142,7 +142,7 @@ class WorkflowService:
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
def get_draft_workflow(
|
||||
self, app_model: App, workflow_id: str | None = None, session: Session | None = None
|
||||
self, app_model: App, workflow_id: str | None = None, session: Session | scoped_session | None = None
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Get draft workflow
|
||||
@ -169,7 +169,7 @@ class WorkflowService:
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(
|
||||
self, app_model: App, workflow_id: str, session: Session | None = None
|
||||
self, app_model: App, workflow_id: str, session: Session | scoped_session | None = None
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
fetch published workflow by workflow_id
|
||||
|
||||
@ -158,6 +158,7 @@ class TestAudioServiceTranscriptTTSMessageLookup:
|
||||
with patch("services.audio_service.ModelManager.for_tenant", return_value=mock_model_manager):
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=db_session_with_containers,
|
||||
message_id=message.id,
|
||||
voice="en-US-Neural",
|
||||
)
|
||||
@ -174,6 +175,7 @@ class TestAudioServiceTranscriptTTSMessageLookup:
|
||||
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=db_session_with_containers,
|
||||
message_id="invalid-uuid",
|
||||
)
|
||||
|
||||
@ -185,6 +187,7 @@ class TestAudioServiceTranscriptTTSMessageLookup:
|
||||
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=db_session_with_containers,
|
||||
message_id=str(uuid4()),
|
||||
)
|
||||
|
||||
@ -205,6 +208,7 @@ class TestAudioServiceTranscriptTTSMessageLookup:
|
||||
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=db_session_with_containers,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
|
||||
@ -176,6 +176,7 @@ class TestAudioServiceMockedBehavior:
|
||||
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=mock_app,
|
||||
session=Mock(),
|
||||
text="Hello world",
|
||||
voice="nova",
|
||||
end_user="user_123",
|
||||
|
||||
@ -398,6 +398,7 @@ class TestAudioServiceTTS:
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=MagicMock(),
|
||||
text="Hello world",
|
||||
voice="en-US-Neural",
|
||||
end_user="user-123",
|
||||
@ -432,6 +433,7 @@ class TestAudioServiceTTS:
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=MagicMock(),
|
||||
text="Test",
|
||||
)
|
||||
|
||||
@ -465,6 +467,7 @@ class TestAudioServiceTTS:
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=MagicMock(),
|
||||
text="Test",
|
||||
)
|
||||
|
||||
@ -496,17 +499,52 @@ class TestAudioServiceTTS:
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"draft audio"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
session = MagicMock()
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=session,
|
||||
text="Draft test",
|
||||
is_draft=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == b"draft audio"
|
||||
mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app)
|
||||
mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app, session=session)
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_message_id_uses_provided_session(
|
||||
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test TTS message lookup uses the injected session."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock(mode=AppMode.CHAT)
|
||||
message_id = "00000000-0000-0000-0000-000000000001"
|
||||
message = factory.create_message_mock(message_id=message_id, answer="Message answer")
|
||||
session = MagicMock()
|
||||
session.get.return_value = message
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"message audio"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
app_model=app,
|
||||
session=session,
|
||||
message_id=message_id,
|
||||
voice="message-voice",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == b"message audio"
|
||||
session.get.assert_called_once_with(Message, message_id)
|
||||
mock_model_instance.invoke_tts.assert_called_once_with(
|
||||
content_text="Message answer",
|
||||
voice="message-voice",
|
||||
)
|
||||
|
||||
def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory):
|
||||
"""Test that TTS raises error when text is missing."""
|
||||
@ -515,7 +553,7 @@ class TestAudioServiceTTS:
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Text is required"):
|
||||
AudioService.transcript_tts(app_model=app, text=None)
|
||||
AudioService.transcript_tts(app_model=app, session=MagicMock(), text=None)
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_raises_error_when_no_voices_available(
|
||||
@ -539,7 +577,7 @@ class TestAudioServiceTTS:
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Sorry, no voice available"):
|
||||
AudioService.transcript_tts(app_model=app, text="Test")
|
||||
AudioService.transcript_tts(app_model=app, session=MagicMock(), text="Test")
|
||||
|
||||
|
||||
class TestAudioServiceTTSVoices:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user