From 7b5c371b9d79e553f433a7d88c6fa51d2aa08425 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 10 May 2026 15:04:42 +0900 Subject: [PATCH] chore: api para type (#35985) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/admin.py | 9 ++- api/controllers/console/app/annotation.py | 75 ++++++++----------- api/controllers/console/app/app.py | 9 ++- api/controllers/console/app/ops_trace.py | 17 +++-- .../console/explore/recommended_app.py | 6 +- api/controllers/console/files.py | 2 +- api/controllers/console/workspace/plugin.py | 2 +- .../console/workspace/workspace.py | 2 +- api/controllers/files/upload.py | 2 +- api/controllers/service_api/app/file.py | 2 +- .../service_api/dataset/document.py | 4 +- .../rag_pipeline/rag_pipeline_workflow.py | 2 +- api/controllers/web/files.py | 2 +- .../prompt_template/manager.py | 2 +- api/services/annotation_service.py | 2 +- api/services/audio_service.py | 2 +- .../trigger_subscription_builder_service.py | 8 +- api/services/trigger/webhook_service.py | 2 +- .../services/test_webhook_service.py | 4 +- .../controllers/files/test_upload.py | 4 +- .../test_prompt_template_manager.py | 15 ++-- .../unit_tests/services/test_audio_service.py | 53 ++++++++----- .../services/test_webhook_service.py | 8 +- 23 files changed, 120 insertions(+), 114 deletions(-) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index bb2f477e3d..ae2b1007dd 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,6 +3,7 @@ import io from collections.abc import Callable from functools import wraps from typing import cast +from uuid import UUID from flask import request from flask_restx import Resource @@ -181,7 +182,7 @@ class InsertExploreAppApi(Resource): @console_ns.response(204, "App removed successfully") @only_edition_cloud @admin_required - def delete(self, app_id): + def delete(self, app_id: UUID): with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) @@ -394,11 +395,11 @@ class BatchAddNotificationAccountsApi(Resource): raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.") try: - content = file.read().decode("utf-8") + content = file.stream.read().decode("utf-8") except UnicodeDecodeError: try: - file.seek(0) - content = file.read().decode("gbk") + file.stream.seek(0) + content = file.stream.read().decode("gbk") except UnicodeDecodeError: raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 5970e55285..cfeaec4af9 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,4 +1,5 @@ from typing import Any, Literal +from uuid import UUID from flask import abort, make_response, request from flask_restx import Resource @@ -115,8 +116,7 @@ class AnnotationReplyActionApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id, action: Literal["enable", "disable"]): - app_id = str(app_id) + def post(self, app_id: UUID, action: Literal["enable", "disable"]): args = AnnotationReplyPayload.model_validate(console_ns.payload) match action: case "enable": @@ -125,9 +125,9 @@ class AnnotationReplyActionApi(Resource): "embedding_provider_name": args.embedding_provider_name, "embedding_model_name": args.embedding_model_name, } - result = AppAnnotationService.enable_app_annotation(enable_args, app_id) + result = AppAnnotationService.enable_app_annotation(enable_args, str(app_id)) case "disable": - result = AppAnnotationService.disable_app_annotation(app_id) + result = AppAnnotationService.disable_app_annotation(str(app_id)) return result, 200 @@ -142,9 +142,8 @@ class AppAnnotationSettingDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): - app_id = str(app_id) - result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) + def get(self, app_id: UUID): + result = AppAnnotationService.get_app_annotation_setting_by_app_id(str(app_id)) return result, 200 @@ -160,14 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, app_id, annotation_setting_id): - app_id = str(app_id) + def post(self, app_id: UUID, annotation_setting_id): annotation_setting_id = str(annotation_setting_id) args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold} - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args) + result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args) return result, 200 @@ -183,7 +181,7 @@ class AnnotationReplyActionStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def get(self, app_id, job_id, action): + def get(self, app_id: UUID, job_id, action): job_id = str(job_id) app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" cache_result = redis_client.get(app_annotation_job_key) @@ -211,14 +209,13 @@ class AnnotationApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): + def get(self, app_id: UUID): args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page limit = args.limit keyword = args.keyword - app_id = str(app_id) - annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(str(app_id), page, limit, keyword) annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) response = AnnotationList( data=annotation_models, @@ -240,8 +237,7 @@ class AnnotationApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id): - app_id = str(app_id) + def post(self, app_id: UUID): args = CreateAnnotationPayload.model_validate(console_ns.payload) upsert_args: UpsertAnnotationArgs = {} if args.answer is not None: @@ -252,15 +248,14 @@ class AnnotationApi(Resource): upsert_args["message_id"] = args.message_id if args.question is not None: upsert_args["question"] = args.question - annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, str(app_id)) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @account_initialization_required @edit_permission_required - def delete(self, app_id): - app_id = str(app_id) + def delete(self, app_id: UUID): # Use request.args.getlist to get annotation_ids array directly annotation_ids = request.args.getlist("annotation_id") @@ -274,11 +269,11 @@ class AnnotationApi(Resource): "message": "annotation_ids are required if the parameter is provided.", }, 400 - result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids) + result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids) return result, 204 # If no annotation_ids are provided, handle clearing all annotations else: - AppAnnotationService.clear_all_annotations(app_id) + AppAnnotationService.clear_all_annotations(str(app_id)) return {"result": "success"}, 204 @@ -297,9 +292,8 @@ class AnnotationExportApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): - app_id = str(app_id) - annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) + def get(self, app_id: UUID): + annotation_list = AppAnnotationService.export_annotation_list_by_app_id(str(app_id)) annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json") @@ -325,26 +319,22 @@ class AnnotationUpdateDeleteApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id, annotation_id): - app_id = str(app_id) - annotation_id = str(annotation_id) + def post(self, app_id: UUID, annotation_id: UUID): args = UpdateAnnotationPayload.model_validate(console_ns.payload) update_args: UpdateAnnotationArgs = {} if args.answer is not None: update_args["answer"] = args.answer if args.question is not None: update_args["question"] = args.question - annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id) + annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id)) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @account_initialization_required @edit_permission_required - def delete(self, app_id, annotation_id): - app_id = str(app_id) - annotation_id = str(annotation_id) - AppAnnotationService.delete_app_annotation(app_id, annotation_id) + def delete(self, app_id: UUID, annotation_id: UUID): + AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id)) return {"result": "success"}, 204 @@ -365,11 +355,9 @@ class AnnotationBatchImportApi(Resource): @annotation_import_rate_limit @annotation_import_concurrency_limit @edit_permission_required - def post(self, app_id): + def post(self, app_id: UUID): from configs import dify_config - app_id = str(app_id) - # check file if "file" not in request.files: raise NoFileUploadedError() @@ -385,9 +373,9 @@ class AnnotationBatchImportApi(Resource): raise ValueError("Invalid file type. Only CSV files are allowed") # Check file size before processing - file.seek(0, 2) # Seek to end of file - file_size = file.tell() - file.seek(0) # Reset to beginning + file.stream.seek(0, 2) # Seek to end of file + file_size = file.stream.tell() + file.stream.seek(0) # Reset to beginning max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 if file_size > max_size_bytes: @@ -400,7 +388,7 @@ class AnnotationBatchImportApi(Resource): if file_size == 0: raise ValueError("The uploaded file is empty") - return AppAnnotationService.batch_import_app_annotations(app_id, file) + return AppAnnotationService.batch_import_app_annotations(str(app_id), file) @console_ns.route("/apps//annotations/batch-import-status/") @@ -415,8 +403,7 @@ class AnnotationBatchImportStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def get(self, app_id, job_id): - job_id = str(job_id) + def get(self, app_id: UUID, job_id: UUID): indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" cache_result = redis_client.get(indexing_cache_key) if cache_result is None: @@ -450,13 +437,11 @@ class AnnotationHitHistoryListApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id, annotation_id): + def get(self, app_id: UUID, annotation_id: UUID): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - app_id = str(app_id) - annotation_id = str(annotation_id) annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( - app_id, annotation_id, page, limit + str(app_id), str(annotation_id), page, limit ) history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python( annotation_hit_history_list, from_attributes=True diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5023d46893..a8ab5bec48 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -3,6 +3,7 @@ import re import uuid from datetime import datetime from typing import Any, Literal +from uuid import UUID from flask import request from flask_restx import Resource @@ -840,10 +841,10 @@ class AppTraceApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + def get(self, app_id: UUID): """Get app trace""" with session_factory.create_session() as session: - app_trace_config = OpsTraceManager.get_app_tracing_config(app_id, session) + app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session) return app_trace_config @@ -857,12 +858,12 @@ class AppTraceApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, app_id): + def post(self, app_id: UUID): # add app trace args = AppTracePayload.model_validate(console_ns.payload) OpsTraceManager.update_app_tracing_config( - app_id=app_id, + app_id=str(app_id), enabled=args.enabled, tracing_provider=args.tracing_provider, ) diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index ee2fc39f86..9227d00a21 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,5 @@ from typing import Any +from uuid import UUID from flask import request from flask_restx import Resource, fields @@ -42,11 +43,11 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + def get(self, app_id: UUID): args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: - trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) + trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider) if not trace_config: return {"has_not_configured": True} return trace_config @@ -64,13 +65,13 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): + def post(self, app_id: UUID): """Create a new trace app configuration""" args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.create_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigIsExist() @@ -89,13 +90,13 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, app_id): + def patch(self, app_id: UUID): """Update an existing trace app configuration""" args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.update_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigNotExist() @@ -112,12 +113,12 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, app_id): + def delete(self, app_id: UUID): """Delete an existing trace app configuration""" args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: - result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) + result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider) if not result: raise TracingConfigNotExist() return {"result": "success"}, 204 diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index bd0e875666..5821b91489 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,4 +1,5 @@ from typing import Any +from uuid import UUID from flask import request from flask_restx import Resource @@ -99,6 +100,5 @@ class RecommendedAppListApi(Resource): class RecommendedAppApi(Resource): @login_required @account_initialization_required - def get(self, app_id): - app_id = str(app_id) - return RecommendedAppService.get_recommend_app_detail(app_id) + def get(self, app_id: UUID): + return RecommendedAppService.get_recommend_app_detail(str(app_id)) diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 109a3cd0d3..9fa5b0f5c1 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -82,7 +82,7 @@ class FileApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source=source, diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 93e7f3acab..a6d4a60beb 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -177,7 +177,7 @@ def _read_upload_content(file: FileStorage, max_size: int) -> bytes: FileStorage.content_length is not reliable for multipart test uploads and may be zero even when content exists, so the controllers validate against the loaded bytes instead. """ - content = file.read() + content = file.stream.read() if len(content) > max_size: raise ValueError("File size exceeds the maximum allowed size") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index a15d8b5918..84890f0443 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -321,7 +321,7 @@ class WebappLogoWorkspaceApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 462e9ef58e..7d588b95dd 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -100,7 +100,7 @@ class PluginUploadFileApi(Resource): tool_file = ToolFileManager().create_file_by_raw( user_id=user.id, tenant_id=tenant_id, - file_binary=file.read(), + file_binary=file.stream.read(), mimetype=mimetype, filename=filename, conversation_id=None, diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 6f6dadf768..687d34076d 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -58,7 +58,7 @@ class FileApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=end_user, ) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 1cf757912f..cb48fe6715 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -432,7 +432,7 @@ class DocumentAddByFileApi(DatasetApiResource): raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source="datasets", @@ -506,7 +506,7 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source="datasets", diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 2dc98bfbf7..8bc43bccd5 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -241,7 +241,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, ) diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 0036c90800..6128490104 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -73,7 +73,7 @@ class FileApi(WebApiResource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=end_user, source="datasets" if source == "datasets" else None, diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 4c07445df3..f4bbbe5d8b 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -75,7 +75,7 @@ class PromptTemplateConfigManager: if not config.get("prompt_type"): config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE - prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] + prompt_type_vals = list(PromptTemplateEntity.PromptType) if config["prompt_type"] not in prompt_type_vals: raise ValueError(f"prompt_type must be in {prompt_type_vals}") diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 0229a1f43a..aa6b8ffc6e 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -425,7 +425,7 @@ class AppAnnotationService: return {"deleted_count": deleted_count} @classmethod - def batch_import_app_annotations(cls, app_id, file: FileStorage): + def batch_import_app_annotations(cls, app_id: str, file: FileStorage): """ Batch import annotations from CSV file with enhanced security checks. diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 60948e652b..c80b2f43fd 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -54,7 +54,7 @@ class AudioService: if extension not in [f"audio/{ext}" for ext in AUDIO_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() - file_content = file.read() + file_content = file.stream.read() file_size = len(file_content) if file_size > FILE_SIZE_LIMIT: diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index 889717df72..cff735b39d 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -121,9 +121,7 @@ class TriggerSubscriptionBuilderService: if not subscription_builder.name: raise ValueError("Subscription builder name is required") - credential_type = CredentialType.of( - subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value - ) + credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED) if credential_type == CredentialType.UNAUTHORIZED: # manually create TriggerProviderService.add_trigger_subscription( @@ -321,9 +319,7 @@ class TriggerSubscriptionBuilderService: raise ValueError("Subscription builder name is required") # Build - credential_type = CredentialType.of( - subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value - ) + credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED) if credential_type == CredentialType.UNAUTHORIZED: # manually create TriggerProviderService.add_trigger_subscription( diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 5d99900a04..592f678421 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -402,7 +402,7 @@ class WebhookService: for name, file in files.items(): if file and file.filename: try: - file_content = file.read() + file_content = file.stream.read() mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream" file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger) processed_files[name] = file_obj.to_dict() diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 6d5c7380b7..52b1229302 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -543,8 +543,8 @@ class TestWebhookService: "bad_file": MagicMock(filename="test.bad", content_type="text/plain"), } - files["good_file"].read.return_value = b"content" - files["bad_file"].read.side_effect = Exception("Read error") + files["good_file"].stream.read.return_value = b"content" + files["bad_file"].stream.read.side_effect = Exception("Read error") webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" diff --git a/api/tests/unit_tests/controllers/files/test_upload.py b/api/tests/unit_tests/controllers/files/test_upload.py index e8f3cd4b66..ff6ba0e9a1 100644 --- a/api/tests/unit_tests/controllers/files/test_upload.py +++ b/api/tests/unit_tests/controllers/files/test_upload.py @@ -1,3 +1,4 @@ +import io import types from unittest.mock import patch @@ -30,9 +31,10 @@ class DummyFile: self.filename = filename self.mimetype = mimetype self._content = content + self.stream = io.BytesIO(content) def read(self): - return self._content + return self.stream.read() class DummyToolFile: diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py index 3fd21ab22b..62e1d22129 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py @@ -1,3 +1,4 @@ +from collections import UserString from unittest.mock import MagicMock import pytest @@ -12,21 +13,25 @@ from core.app.app_config.easy_ui_based_app.prompt_template.manager import ( # ----------------------------- -class DummyEnumValue: +class DummyEnumValue(UserString): def __init__(self, value): + super().__init__(value) self.value = value class DummyPromptType: def __init__(self): - self.SIMPLE = "simple" - self.ADVANCED = "advanced" + self.SIMPLE = DummyEnumValue("simple") + self.ADVANCED = DummyEnumValue("advanced") def value_of(self, value): - return value + for enum_value in self: + if enum_value.value == value: + return enum_value + raise ValueError(f"invalid prompt type value {value}") def __iter__(self): - return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")]) + return iter([self.SIMPLE, self.ADVANCED]) # ----------------------------- diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 83258fd1b7..5d148974f8 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -173,7 +173,8 @@ class AudioServiceTestDataFactory: file = Mock(spec=FileStorage) file.filename = filename file.mimetype = mimetype - file.read = Mock(return_value=content) + file.stream = Mock() + file.stream.read = Mock(return_value=content) for key, value in kwargs.items(): setattr(file, key, value) return file @@ -216,7 +217,7 @@ class TestAudioServiceASR: """Test speech-to-text (ASR) operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory): + def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful ASR transcription in CHAT mode.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -241,7 +242,9 @@ class TestAudioServiceASR: mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory): + def test_transcript_asr_success_advanced_chat_mode( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test successful ASR transcription in ADVANCED_CHAT mode.""" # Arrange workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}}) @@ -263,7 +266,7 @@ class TestAudioServiceASR: # Assert assert result == {"text": "Workflow transcribed text"} - def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory): + def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when speech-to-text is disabled in CHAT mode.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False}) @@ -277,7 +280,9 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory): + def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode( + self, factory: AudioServiceTestDataFactory + ): """Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode.""" # Arrange workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}}) @@ -291,7 +296,7 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_workflow_missing(self, factory): + def test_transcript_asr_raises_error_when_workflow_missing(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when workflow is missing in WORKFLOW mode.""" # Arrange app = factory.create_app_mock( @@ -304,7 +309,7 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory): + def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when no file is uploaded.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -317,7 +322,7 @@ class TestAudioServiceASR: with pytest.raises(NoAudioUploadedServiceError): AudioService.transcript_asr(app_model=app, file=None) - def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory): + def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error for unsupported audio file types.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -331,7 +336,7 @@ class TestAudioServiceASR: with pytest.raises(UnsupportedAudioTypeServiceError): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_for_large_file(self, factory): + def test_transcript_asr_raises_error_for_large_file(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when file exceeds size limit (30MB).""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -348,7 +353,9 @@ class TestAudioServiceASR: AudioService.transcript_asr(app_model=app, file=file) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): + def test_transcript_asr_raises_error_when_no_model_instance( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that ASR raises error when no model instance is available.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -371,7 +378,7 @@ class TestAudioServiceTTS: """Test text-to-speech (TTS) operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory): + def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful TTS with text input.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -405,7 +412,7 @@ class TestAudioServiceTTS: ) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory): + def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test TTS uses default voice when none specified.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -435,7 +442,9 @@ class TestAudioServiceTTS: assert call_args.kwargs["voice"] == "default-voice" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory): + def test_transcript_tts_gets_first_available_voice_when_none_configured( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test TTS gets first available voice when none is configured.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -467,7 +476,7 @@ class TestAudioServiceTTS: @patch("services.audio_service.WorkflowService", autospec=True) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_workflow_mode_with_draft( - self, mock_model_manager_class, mock_workflow_service_class, factory + self, mock_model_manager_class, mock_workflow_service_class, factory: AudioServiceTestDataFactory ): """Test TTS in WORKFLOW mode with draft workflow.""" # Arrange @@ -499,7 +508,7 @@ class TestAudioServiceTTS: assert result == b"draft audio" mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app) - def test_transcript_tts_raises_error_when_text_missing(self, factory): + def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory): """Test that TTS raises error when text is missing.""" # Arrange app = factory.create_app_mock() @@ -509,7 +518,9 @@ class TestAudioServiceTTS: AudioService.transcript_tts(app_model=app, text=None) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory): + def test_transcript_tts_raises_error_when_no_voices_available( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS raises error when no voices are available.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -535,7 +546,7 @@ class TestAudioServiceTTSVoices: """Test TTS voice listing operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_success(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful retrieval of TTS voices.""" # Arrange tenant_id = "tenant-123" @@ -560,7 +571,9 @@ class TestAudioServiceTTSVoices: mock_model_instance.get_tts_voices.assert_called_once_with(language) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_raises_error_when_no_model_instance( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS voices raises error when no model instance is available.""" # Arrange tenant_id = "tenant-123" @@ -575,7 +588,9 @@ class TestAudioServiceTTSVoices: AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_propagates_exceptions( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS voices propagates exceptions from model instance.""" # Arrange tenant_id = "tenant-123" diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 95edc436d7..a2b56fe777 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -268,8 +268,8 @@ class TestWebhookServiceUnit: } # Mock file reads - files["file1"].read.return_value = b"content1" - files["file2"].read.return_value = b"content2" + files["file1"].stream.read.return_value = b"content1" + files["file2"].stream.read.return_value = b"content2" webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" @@ -304,8 +304,8 @@ class TestWebhookServiceUnit: "bad_file": MagicMock(filename="test.bad", content_type="text/plain"), } - files["good_file"].read.return_value = b"content" - files["bad_file"].read.side_effect = Exception("Read error") + files["good_file"].stream.read.return_value = b"content" + files["bad_file"].stream.read.side_effect = Exception("Read error") webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant"