chore: api para type (#35985)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-05-10 15:04:42 +09:00 committed by GitHub
parent c67ce6f66d
commit 7b5c371b9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 120 additions and 114 deletions

View File

@ -3,6 +3,7 @@ import io
from collections.abc import Callable
from functools import wraps
from typing import cast
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -181,7 +182,7 @@ class InsertExploreAppApi(Resource):
@console_ns.response(204, "App removed successfully")
@only_edition_cloud
@admin_required
def delete(self, app_id):
def delete(self, app_id: UUID):
with session_factory.create_session() as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
@ -394,11 +395,11 @@ class BatchAddNotificationAccountsApi(Resource):
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
try:
content = file.read().decode("utf-8")
content = file.stream.read().decode("utf-8")
except UnicodeDecodeError:
try:
file.seek(0)
content = file.read().decode("gbk")
file.stream.seek(0)
content = file.stream.read().decode("gbk")
except UnicodeDecodeError:
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")

View File

@ -1,4 +1,5 @@
from typing import Any, Literal
from uuid import UUID
from flask import abort, make_response, request
from flask_restx import Resource
@ -115,8 +116,7 @@ class AnnotationReplyActionApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
def post(self, app_id: UUID, action: Literal["enable", "disable"]):
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
@ -125,9 +125,9 @@ class AnnotationReplyActionApi(Resource):
"embedding_provider_name": args.embedding_provider_name,
"embedding_model_name": args.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
result = AppAnnotationService.enable_app_annotation(enable_args, str(app_id))
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
result = AppAnnotationService.disable_app_annotation(str(app_id))
return result, 200
@ -142,9 +142,8 @@ class AppAnnotationSettingDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
app_id = str(app_id)
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
def get(self, app_id: UUID):
result = AppAnnotationService.get_app_annotation_setting_by_app_id(str(app_id))
return result, 200
@ -160,14 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id, annotation_setting_id):
app_id = str(app_id)
def post(self, app_id: UUID, annotation_setting_id):
annotation_setting_id = str(annotation_setting_id)
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args)
return result, 200
@ -183,7 +181,7 @@ class AnnotationReplyActionStatusApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id, action):
def get(self, app_id: UUID, job_id, action):
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
@ -211,14 +209,13 @@ class AnnotationApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
def get(self, app_id: UUID):
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True))
page = args.page
limit = args.limit
keyword = args.keyword
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(str(app_id), page, limit, keyword)
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
@ -240,8 +237,7 @@ class AnnotationApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
def post(self, app_id: UUID):
args = CreateAnnotationPayload.model_validate(console_ns.payload)
upsert_args: UpsertAnnotationArgs = {}
if args.answer is not None:
@ -252,15 +248,14 @@ class AnnotationApi(Resource):
upsert_args["message_id"] = args.message_id
if args.question is not None:
upsert_args["question"] = args.question
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, str(app_id))
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id):
app_id = str(app_id)
def delete(self, app_id: UUID):
# Use request.args.getlist to get annotation_ids array directly
annotation_ids = request.args.getlist("annotation_id")
@ -274,11 +269,11 @@ class AnnotationApi(Resource):
"message": "annotation_ids are required if the parameter is provided.",
}, 400
result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids)
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
return result, 204
# If no annotation_ids are provided, handle clearing all annotations
else:
AppAnnotationService.clear_all_annotations(app_id)
AppAnnotationService.clear_all_annotations(str(app_id))
return {"result": "success"}, 204
@ -297,9 +292,8 @@ class AnnotationExportApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
def get(self, app_id: UUID):
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(str(app_id))
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
@ -325,26 +319,22 @@ class AnnotationUpdateDeleteApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
def post(self, app_id: UUID, annotation_id: UUID):
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
update_args: UpdateAnnotationArgs = {}
if args.answer is not None:
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id))
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
def delete(self, app_id: UUID, annotation_id: UUID):
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
return {"result": "success"}, 204
@ -365,11 +355,9 @@ class AnnotationBatchImportApi(Resource):
@annotation_import_rate_limit
@annotation_import_concurrency_limit
@edit_permission_required
def post(self, app_id):
def post(self, app_id: UUID):
from configs import dify_config
app_id = str(app_id)
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -385,9 +373,9 @@ class AnnotationBatchImportApi(Resource):
raise ValueError("Invalid file type. Only CSV files are allowed")
# Check file size before processing
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
file.stream.seek(0, 2) # Seek to end of file
file_size = file.stream.tell()
file.stream.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
@ -400,7 +388,7 @@ class AnnotationBatchImportApi(Resource):
if file_size == 0:
raise ValueError("The uploaded file is empty")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
return AppAnnotationService.batch_import_app_annotations(str(app_id), file)
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
@ -415,8 +403,7 @@ class AnnotationBatchImportStatusApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id, job_id):
job_id = str(job_id)
def get(self, app_id: UUID, job_id: UUID):
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
cache_result = redis_client.get(indexing_cache_key)
if cache_result is None:
@ -450,13 +437,11 @@ class AnnotationHitHistoryListApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_id, annotation_id):
def get(self, app_id: UUID, annotation_id: UUID):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
app_id = str(app_id)
annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
str(app_id), str(annotation_id), page, limit
)
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
annotation_hit_history_list, from_attributes=True

View File

@ -3,6 +3,7 @@ import re
import uuid
from datetime import datetime
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -840,10 +841,10 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
def get(self, app_id: UUID):
"""Get app trace"""
with session_factory.create_session() as session:
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id, session)
app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session)
return app_trace_config
@ -857,12 +858,12 @@ class AppTraceApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id):
def post(self, app_id: UUID):
# add app trace
args = AppTracePayload.model_validate(console_ns.payload)
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
app_id=str(app_id),
enabled=args.enabled,
tracing_provider=args.tracing_provider,
)

View File

@ -1,4 +1,5 @@
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import Resource, fields
@ -42,11 +43,11 @@ class TraceAppConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
def get(self, app_id: UUID):
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
if not trace_config:
return {"has_not_configured": True}
return trace_config
@ -64,13 +65,13 @@ class TraceAppConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
def post(self, app_id: UUID):
"""Create a new trace app configuration"""
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigIsExist()
@ -89,13 +90,13 @@ class TraceAppConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, app_id):
def patch(self, app_id: UUID):
"""Update an existing trace app configuration"""
args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigNotExist()
@ -112,12 +113,12 @@ class TraceAppConfigApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id):
def delete(self, app_id: UUID):
"""Delete an existing trace app configuration"""
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204

View File

@ -1,4 +1,5 @@
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -99,6 +100,5 @@ class RecommendedAppListApi(Resource):
class RecommendedAppApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
app_id = str(app_id)
return RecommendedAppService.get_recommend_app_detail(app_id)
def get(self, app_id: UUID):
return RecommendedAppService.get_recommend_app_detail(str(app_id))

View File

@ -82,7 +82,7 @@ class FileApi(Resource):
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
content=file.stream.read(),
mimetype=file.mimetype,
user=current_user,
source=source,

View File

@ -177,7 +177,7 @@ def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
FileStorage.content_length is not reliable for multipart test uploads and may be zero even when
content exists, so the controllers validate against the loaded bytes instead.
"""
content = file.read()
content = file.stream.read()
if len(content) > max_size:
raise ValueError("File size exceeds the maximum allowed size")

View File

@ -321,7 +321,7 @@ class WebappLogoWorkspaceApi(Resource):
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
content=file.stream.read(),
mimetype=file.mimetype,
user=current_user,
)

View File

@ -100,7 +100,7 @@ class PluginUploadFileApi(Resource):
tool_file = ToolFileManager().create_file_by_raw(
user_id=user.id,
tenant_id=tenant_id,
file_binary=file.read(),
file_binary=file.stream.read(),
mimetype=mimetype,
filename=filename,
conversation_id=None,

View File

@ -58,7 +58,7 @@ class FileApi(Resource):
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
content=file.stream.read(),
mimetype=file.mimetype,
user=end_user,
)

View File

@ -432,7 +432,7 @@ class DocumentAddByFileApi(DatasetApiResource):
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
content=file.stream.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
@ -506,7 +506,7 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
content=file.stream.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",

View File

@ -241,7 +241,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
content=file.stream.read(),
mimetype=file.mimetype,
user=current_user,
)

View File

@ -73,7 +73,7 @@ class FileApi(WebApiResource):
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
content=file.stream.read(),
mimetype=file.mimetype,
user=end_user,
source="datasets" if source == "datasets" else None,

View File

@ -75,7 +75,7 @@ class PromptTemplateConfigManager:
if not config.get("prompt_type"):
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
prompt_type_vals = list(PromptTemplateEntity.PromptType)
if config["prompt_type"] not in prompt_type_vals:
raise ValueError(f"prompt_type must be in {prompt_type_vals}")

View File

@ -425,7 +425,7 @@ class AppAnnotationService:
return {"deleted_count": deleted_count}
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage):
def batch_import_app_annotations(cls, app_id: str, file: FileStorage):
"""
Batch import annotations from CSV file with enhanced security checks.

View File

@ -54,7 +54,7 @@ class AudioService:
if extension not in [f"audio/{ext}" for ext in AUDIO_EXTENSIONS]:
raise UnsupportedAudioTypeServiceError()
file_content = file.read()
file_content = file.stream.read()
file_size = len(file_content)
if file_size > FILE_SIZE_LIMIT:

View File

@ -121,9 +121,7 @@ class TriggerSubscriptionBuilderService:
if not subscription_builder.name:
raise ValueError("Subscription builder name is required")
credential_type = CredentialType.of(
subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
)
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED)
if credential_type == CredentialType.UNAUTHORIZED:
# manually create
TriggerProviderService.add_trigger_subscription(
@ -321,9 +319,7 @@ class TriggerSubscriptionBuilderService:
raise ValueError("Subscription builder name is required")
# Build
credential_type = CredentialType.of(
subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
)
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED)
if credential_type == CredentialType.UNAUTHORIZED:
# manually create
TriggerProviderService.add_trigger_subscription(

View File

@ -402,7 +402,7 @@ class WebhookService:
for name, file in files.items():
if file and file.filename:
try:
file_content = file.read()
file_content = file.stream.read()
mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
processed_files[name] = file_obj.to_dict()

View File

@ -543,8 +543,8 @@ class TestWebhookService:
"bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
}
files["good_file"].read.return_value = b"content"
files["bad_file"].read.side_effect = Exception("Read error")
files["good_file"].stream.read.return_value = b"content"
files["bad_file"].stream.read.side_effect = Exception("Read error")
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"

View File

@ -1,3 +1,4 @@
import io
import types
from unittest.mock import patch
@ -30,9 +31,10 @@ class DummyFile:
self.filename = filename
self.mimetype = mimetype
self._content = content
self.stream = io.BytesIO(content)
def read(self):
return self._content
return self.stream.read()
class DummyToolFile:

View File

@ -1,3 +1,4 @@
from collections import UserString
from unittest.mock import MagicMock
import pytest
@ -12,21 +13,25 @@ from core.app.app_config.easy_ui_based_app.prompt_template.manager import (
# -----------------------------
class DummyEnumValue:
class DummyEnumValue(UserString):
def __init__(self, value):
super().__init__(value)
self.value = value
class DummyPromptType:
def __init__(self):
self.SIMPLE = "simple"
self.ADVANCED = "advanced"
self.SIMPLE = DummyEnumValue("simple")
self.ADVANCED = DummyEnumValue("advanced")
def value_of(self, value):
return value
for enum_value in self:
if enum_value.value == value:
return enum_value
raise ValueError(f"invalid prompt type value {value}")
def __iter__(self):
return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")])
return iter([self.SIMPLE, self.ADVANCED])
# -----------------------------

View File

@ -173,7 +173,8 @@ class AudioServiceTestDataFactory:
file = Mock(spec=FileStorage)
file.filename = filename
file.mimetype = mimetype
file.read = Mock(return_value=content)
file.stream = Mock()
file.stream.read = Mock(return_value=content)
for key, value in kwargs.items():
setattr(file, key, value)
return file
@ -216,7 +217,7 @@ class TestAudioServiceASR:
"""Test speech-to-text (ASR) operations."""
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory: AudioServiceTestDataFactory):
"""Test successful ASR transcription in CHAT mode."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
@ -241,7 +242,9 @@ class TestAudioServiceASR:
mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123")
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
def test_transcript_asr_success_advanced_chat_mode(
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
):
"""Test successful ASR transcription in ADVANCED_CHAT mode."""
# Arrange
workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}})
@ -263,7 +266,7 @@ class TestAudioServiceASR:
# Assert
assert result == {"text": "Workflow transcribed text"}
def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory):
def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory: AudioServiceTestDataFactory):
"""Test that ASR raises error when speech-to-text is disabled in CHAT mode."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False})
@ -277,7 +280,9 @@ class TestAudioServiceASR:
with pytest.raises(ValueError, match="Speech to text is not enabled"):
AudioService.transcript_asr(app_model=app, file=file)
def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory):
def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(
self, factory: AudioServiceTestDataFactory
):
"""Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode."""
# Arrange
workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}})
@ -291,7 +296,7 @@ class TestAudioServiceASR:
with pytest.raises(ValueError, match="Speech to text is not enabled"):
AudioService.transcript_asr(app_model=app, file=file)
def test_transcript_asr_raises_error_when_workflow_missing(self, factory):
def test_transcript_asr_raises_error_when_workflow_missing(self, factory: AudioServiceTestDataFactory):
"""Test that ASR raises error when workflow is missing in WORKFLOW mode."""
# Arrange
app = factory.create_app_mock(
@ -304,7 +309,7 @@ class TestAudioServiceASR:
with pytest.raises(ValueError, match="Speech to text is not enabled"):
AudioService.transcript_asr(app_model=app, file=file)
def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory):
def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory: AudioServiceTestDataFactory):
"""Test that ASR raises error when no file is uploaded."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
@ -317,7 +322,7 @@ class TestAudioServiceASR:
with pytest.raises(NoAudioUploadedServiceError):
AudioService.transcript_asr(app_model=app, file=None)
def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory):
def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory: AudioServiceTestDataFactory):
"""Test that ASR raises error for unsupported audio file types."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
@ -331,7 +336,7 @@ class TestAudioServiceASR:
with pytest.raises(UnsupportedAudioTypeServiceError):
AudioService.transcript_asr(app_model=app, file=file)
def test_transcript_asr_raises_error_for_large_file(self, factory):
def test_transcript_asr_raises_error_for_large_file(self, factory: AudioServiceTestDataFactory):
"""Test that ASR raises error when file exceeds size limit (30MB)."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
@ -348,7 +353,9 @@ class TestAudioServiceASR:
AudioService.transcript_asr(app_model=app, file=file)
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
def test_transcript_asr_raises_error_when_no_model_instance(
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
):
"""Test that ASR raises error when no model instance is available."""
# Arrange
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
@ -371,7 +378,7 @@ class TestAudioServiceTTS:
"""Test text-to-speech (TTS) operations."""
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory):
"""Test successful TTS with text input."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
@ -405,7 +412,7 @@ class TestAudioServiceTTS:
)
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory: AudioServiceTestDataFactory):
"""Test TTS uses default voice when none specified."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
@ -435,7 +442,9 @@ class TestAudioServiceTTS:
assert call_args.kwargs["voice"] == "default-voice"
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
def test_transcript_tts_gets_first_available_voice_when_none_configured(
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
):
"""Test TTS gets first available voice when none is configured."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
@ -467,7 +476,7 @@ class TestAudioServiceTTS:
@patch("services.audio_service.WorkflowService", autospec=True)
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_workflow_mode_with_draft(
self, mock_model_manager_class, mock_workflow_service_class, factory
self, mock_model_manager_class, mock_workflow_service_class, factory: AudioServiceTestDataFactory
):
"""Test TTS in WORKFLOW mode with draft workflow."""
# Arrange
@ -499,7 +508,7 @@ class TestAudioServiceTTS:
assert result == b"draft audio"
mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app)
def test_transcript_tts_raises_error_when_text_missing(self, factory):
def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory):
"""Test that TTS raises error when text is missing."""
# Arrange
app = factory.create_app_mock()
@ -509,7 +518,9 @@ class TestAudioServiceTTS:
AudioService.transcript_tts(app_model=app, text=None)
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
def test_transcript_tts_raises_error_when_no_voices_available(
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
):
"""Test that TTS raises error when no voices are available."""
# Arrange
app_model_config = factory.create_app_model_config_mock(
@ -535,7 +546,7 @@ class TestAudioServiceTTSVoices:
"""Test TTS voice listing operations."""
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory):
"""Test successful retrieval of TTS voices."""
# Arrange
tenant_id = "tenant-123"
@ -560,7 +571,9 @@ class TestAudioServiceTTSVoices:
mock_model_instance.get_tts_voices.assert_called_once_with(language)
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
def test_transcript_tts_voices_raises_error_when_no_model_instance(
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
):
"""Test that TTS voices raises error when no model instance is available."""
# Arrange
tenant_id = "tenant-123"
@ -575,7 +588,9 @@ class TestAudioServiceTTSVoices:
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
def test_transcript_tts_voices_propagates_exceptions(
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
):
"""Test that TTS voices propagates exceptions from model instance."""
# Arrange
tenant_id = "tenant-123"

View File

@ -268,8 +268,8 @@ class TestWebhookServiceUnit:
}
# Mock file reads
files["file1"].read.return_value = b"content1"
files["file2"].read.return_value = b"content2"
files["file1"].stream.read.return_value = b"content1"
files["file2"].stream.read.return_value = b"content2"
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
@ -304,8 +304,8 @@ class TestWebhookServiceUnit:
"bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
}
files["good_file"].read.return_value = b"content"
files["bad_file"].read.side_effect = Exception("Read error")
files["good_file"].stream.read.return_value = b"content"
files["bad_file"].stream.read.side_effect = Exception("Read error")
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"