mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 22:28:55 +08:00
chore: api para type (#35985)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c67ce6f66d
commit
7b5c371b9d
@ -3,6 +3,7 @@ import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -181,7 +182,7 @@ class InsertExploreAppApi(Resource):
|
||||
@console_ns.response(204, "App removed successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
def delete(self, app_id: UUID):
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
|
||||
@ -394,11 +395,11 @@ class BatchAddNotificationAccountsApi(Resource):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
content = file.stream.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
file.stream.seek(0)
|
||||
content = file.stream.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, make_response, request
|
||||
from flask_restx import Resource
|
||||
@ -115,8 +116,7 @@ class AnnotationReplyActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
app_id = str(app_id)
|
||||
def post(self, app_id: UUID, action: Literal["enable", "disable"]):
|
||||
args = AnnotationReplyPayload.model_validate(console_ns.payload)
|
||||
match action:
|
||||
case "enable":
|
||||
@ -125,9 +125,9 @@ class AnnotationReplyActionApi(Resource):
|
||||
"embedding_provider_name": args.embedding_provider_name,
|
||||
"embedding_model_name": args.embedding_model_name,
|
||||
}
|
||||
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
|
||||
result = AppAnnotationService.enable_app_annotation(enable_args, str(app_id))
|
||||
case "disable":
|
||||
result = AppAnnotationService.disable_app_annotation(app_id)
|
||||
result = AppAnnotationService.disable_app_annotation(str(app_id))
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -142,9 +142,8 @@ class AppAnnotationSettingDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||
def get(self, app_id: UUID):
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(str(app_id))
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -160,14 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
app_id = str(app_id)
|
||||
def post(self, app_id: UUID, annotation_setting_id):
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
|
||||
result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -183,7 +181,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id, action):
|
||||
def get(self, app_id: UUID, job_id, action):
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
@ -211,14 +209,13 @@ class AnnotationApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
def get(self, app_id: UUID):
|
||||
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
keyword = args.keyword
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(str(app_id), page, limit, keyword)
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response = AnnotationList(
|
||||
data=annotation_models,
|
||||
@ -240,8 +237,7 @@ class AnnotationApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
def post(self, app_id: UUID):
|
||||
args = CreateAnnotationPayload.model_validate(console_ns.payload)
|
||||
upsert_args: UpsertAnnotationArgs = {}
|
||||
if args.answer is not None:
|
||||
@ -252,15 +248,14 @@ class AnnotationApi(Resource):
|
||||
upsert_args["message_id"] = args.message_id
|
||||
if args.question is not None:
|
||||
upsert_args["question"] = args.question
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, str(app_id))
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id):
|
||||
app_id = str(app_id)
|
||||
def delete(self, app_id: UUID):
|
||||
|
||||
# Use request.args.getlist to get annotation_ids array directly
|
||||
annotation_ids = request.args.getlist("annotation_id")
|
||||
@ -274,11 +269,11 @@ class AnnotationApi(Resource):
|
||||
"message": "annotation_ids are required if the parameter is provided.",
|
||||
}, 400
|
||||
|
||||
result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids)
|
||||
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
|
||||
return result, 204
|
||||
# If no annotation_ids are provided, handle clearing all annotations
|
||||
else:
|
||||
AppAnnotationService.clear_all_annotations(app_id)
|
||||
AppAnnotationService.clear_all_annotations(str(app_id))
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@ -297,9 +292,8 @@ class AnnotationExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
def get(self, app_id: UUID):
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(str(app_id))
|
||||
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
|
||||
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
|
||||
|
||||
@ -325,26 +319,22 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
def post(self, app_id: UUID, annotation_id: UUID):
|
||||
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
|
||||
update_args: UpdateAnnotationArgs = {}
|
||||
if args.answer is not None:
|
||||
update_args["answer"] = args.answer
|
||||
if args.question is not None:
|
||||
update_args["question"] = args.question
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id))
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||
def delete(self, app_id: UUID, annotation_id: UUID):
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@ -365,11 +355,9 @@ class AnnotationBatchImportApi(Resource):
|
||||
@annotation_import_rate_limit
|
||||
@annotation_import_concurrency_limit
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
def post(self, app_id: UUID):
|
||||
from configs import dify_config
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
@ -385,9 +373,9 @@ class AnnotationBatchImportApi(Resource):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
# Check file size before processing
|
||||
file.seek(0, 2) # Seek to end of file
|
||||
file_size = file.tell()
|
||||
file.seek(0) # Reset to beginning
|
||||
file.stream.seek(0, 2) # Seek to end of file
|
||||
file_size = file.stream.tell()
|
||||
file.stream.seek(0) # Reset to beginning
|
||||
|
||||
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
if file_size > max_size_bytes:
|
||||
@ -400,7 +388,7 @@ class AnnotationBatchImportApi(Resource):
|
||||
if file_size == 0:
|
||||
raise ValueError("The uploaded file is empty")
|
||||
|
||||
return AppAnnotationService.batch_import_app_annotations(app_id, file)
|
||||
return AppAnnotationService.batch_import_app_annotations(str(app_id), file)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||
@ -415,8 +403,7 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id):
|
||||
job_id = str(job_id)
|
||||
def get(self, app_id: UUID, job_id: UUID):
|
||||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is None:
|
||||
@ -450,13 +437,11 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id, annotation_id):
|
||||
def get(self, app_id: UUID, annotation_id: UUID):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
|
||||
app_id, annotation_id, page, limit
|
||||
str(app_id), str(annotation_id), page, limit
|
||||
)
|
||||
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
|
||||
annotation_hit_history_list, from_attributes=True
|
||||
|
||||
@ -3,6 +3,7 @@ import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -840,10 +841,10 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
def get(self, app_id: UUID):
|
||||
"""Get app trace"""
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id, session)
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
@ -857,12 +858,12 @@ class AppTraceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
def post(self, app_id: UUID):
|
||||
# add app trace
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
app_id=str(app_id),
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
@ -42,11 +43,11 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
def get(self, app_id: UUID):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
|
||||
if not trace_config:
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
@ -64,13 +65,13 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
def post(self, app_id: UUID):
|
||||
"""Create a new trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
@ -89,13 +90,13 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, app_id):
|
||||
def patch(self, app_id: UUID):
|
||||
"""Update an existing trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@ -112,12 +113,12 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
def delete(self, app_id: UUID):
|
||||
"""Delete an existing trace app configuration"""
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -99,6 +100,5 @@ class RecommendedAppListApi(Resource):
|
||||
class RecommendedAppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
return RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
def get(self, app_id: UUID):
|
||||
return RecommendedAppService.get_recommend_app_detail(str(app_id))
|
||||
|
||||
@ -82,7 +82,7 @@ class FileApi(Resource):
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source=source,
|
||||
|
||||
@ -177,7 +177,7 @@ def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
|
||||
FileStorage.content_length is not reliable for multipart test uploads and may be zero even when
|
||||
content exists, so the controllers validate against the loaded bytes instead.
|
||||
"""
|
||||
content = file.read()
|
||||
content = file.stream.read()
|
||||
if len(content) > max_size:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
|
||||
@ -321,7 +321,7 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
@ -100,7 +100,7 @@ class PluginUploadFileApi(Resource):
|
||||
tool_file = ToolFileManager().create_file_by_raw(
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
file_binary=file.read(),
|
||||
file_binary=file.stream.read(),
|
||||
mimetype=mimetype,
|
||||
filename=filename,
|
||||
conversation_id=None,
|
||||
|
||||
@ -58,7 +58,7 @@ class FileApi(Resource):
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=end_user,
|
||||
)
|
||||
|
||||
@ -432,7 +432,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
@ -506,7 +506,7 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
|
||||
@ -241,7 +241,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
@ -73,7 +73,7 @@ class FileApi(WebApiResource):
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=end_user,
|
||||
source="datasets" if source == "datasets" else None,
|
||||
|
||||
@ -75,7 +75,7 @@ class PromptTemplateConfigManager:
|
||||
if not config.get("prompt_type"):
|
||||
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE
|
||||
|
||||
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
|
||||
prompt_type_vals = list(PromptTemplateEntity.PromptType)
|
||||
if config["prompt_type"] not in prompt_type_vals:
|
||||
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
|
||||
|
||||
|
||||
@ -425,7 +425,7 @@ class AppAnnotationService:
|
||||
return {"deleted_count": deleted_count}
|
||||
|
||||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage):
|
||||
def batch_import_app_annotations(cls, app_id: str, file: FileStorage):
|
||||
"""
|
||||
Batch import annotations from CSV file with enhanced security checks.
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ class AudioService:
|
||||
if extension not in [f"audio/{ext}" for ext in AUDIO_EXTENSIONS]:
|
||||
raise UnsupportedAudioTypeServiceError()
|
||||
|
||||
file_content = file.read()
|
||||
file_content = file.stream.read()
|
||||
file_size = len(file_content)
|
||||
|
||||
if file_size > FILE_SIZE_LIMIT:
|
||||
|
||||
@ -121,9 +121,7 @@ class TriggerSubscriptionBuilderService:
|
||||
if not subscription_builder.name:
|
||||
raise ValueError("Subscription builder name is required")
|
||||
|
||||
credential_type = CredentialType.of(
|
||||
subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
|
||||
)
|
||||
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED)
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
# manually create
|
||||
TriggerProviderService.add_trigger_subscription(
|
||||
@ -321,9 +319,7 @@ class TriggerSubscriptionBuilderService:
|
||||
raise ValueError("Subscription builder name is required")
|
||||
|
||||
# Build
|
||||
credential_type = CredentialType.of(
|
||||
subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value
|
||||
)
|
||||
credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED)
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
# manually create
|
||||
TriggerProviderService.add_trigger_subscription(
|
||||
|
||||
@ -402,7 +402,7 @@ class WebhookService:
|
||||
for name, file in files.items():
|
||||
if file and file.filename:
|
||||
try:
|
||||
file_content = file.read()
|
||||
file_content = file.stream.read()
|
||||
mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
|
||||
file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger)
|
||||
processed_files[name] = file_obj.to_dict()
|
||||
|
||||
@ -543,8 +543,8 @@ class TestWebhookService:
|
||||
"bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
|
||||
}
|
||||
|
||||
files["good_file"].read.return_value = b"content"
|
||||
files["bad_file"].read.side_effect = Exception("Read error")
|
||||
files["good_file"].stream.read.return_value = b"content"
|
||||
files["bad_file"].stream.read.side_effect = Exception("Read error")
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import io
|
||||
import types
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -30,9 +31,10 @@ class DummyFile:
|
||||
self.filename = filename
|
||||
self.mimetype = mimetype
|
||||
self._content = content
|
||||
self.stream = io.BytesIO(content)
|
||||
|
||||
def read(self):
|
||||
return self._content
|
||||
return self.stream.read()
|
||||
|
||||
|
||||
class DummyToolFile:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from collections import UserString
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@ -12,21 +13,25 @@ from core.app.app_config.easy_ui_based_app.prompt_template.manager import (
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class DummyEnumValue:
|
||||
class DummyEnumValue(UserString):
|
||||
def __init__(self, value):
|
||||
super().__init__(value)
|
||||
self.value = value
|
||||
|
||||
|
||||
class DummyPromptType:
|
||||
def __init__(self):
|
||||
self.SIMPLE = "simple"
|
||||
self.ADVANCED = "advanced"
|
||||
self.SIMPLE = DummyEnumValue("simple")
|
||||
self.ADVANCED = DummyEnumValue("advanced")
|
||||
|
||||
def value_of(self, value):
|
||||
return value
|
||||
for enum_value in self:
|
||||
if enum_value.value == value:
|
||||
return enum_value
|
||||
raise ValueError(f"invalid prompt type value {value}")
|
||||
|
||||
def __iter__(self):
|
||||
return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")])
|
||||
return iter([self.SIMPLE, self.ADVANCED])
|
||||
|
||||
|
||||
# -----------------------------
|
||||
|
||||
@ -173,7 +173,8 @@ class AudioServiceTestDataFactory:
|
||||
file = Mock(spec=FileStorage)
|
||||
file.filename = filename
|
||||
file.mimetype = mimetype
|
||||
file.read = Mock(return_value=content)
|
||||
file.stream = Mock()
|
||||
file.stream.read = Mock(return_value=content)
|
||||
for key, value in kwargs.items():
|
||||
setattr(file, key, value)
|
||||
return file
|
||||
@ -216,7 +217,7 @@ class TestAudioServiceASR:
|
||||
"""Test speech-to-text (ASR) operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
|
||||
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory: AudioServiceTestDataFactory):
|
||||
"""Test successful ASR transcription in CHAT mode."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
@ -241,7 +242,9 @@ class TestAudioServiceASR:
|
||||
mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123")
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
|
||||
def test_transcript_asr_success_advanced_chat_mode(
|
||||
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test successful ASR transcription in ADVANCED_CHAT mode."""
|
||||
# Arrange
|
||||
workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}})
|
||||
@ -263,7 +266,7 @@ class TestAudioServiceASR:
|
||||
# Assert
|
||||
assert result == {"text": "Workflow transcribed text"}
|
||||
|
||||
def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory):
|
||||
def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory: AudioServiceTestDataFactory):
|
||||
"""Test that ASR raises error when speech-to-text is disabled in CHAT mode."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False})
|
||||
@ -277,7 +280,9 @@ class TestAudioServiceASR:
|
||||
with pytest.raises(ValueError, match="Speech to text is not enabled"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory):
|
||||
def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(
|
||||
self, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode."""
|
||||
# Arrange
|
||||
workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}})
|
||||
@ -291,7 +296,7 @@ class TestAudioServiceASR:
|
||||
with pytest.raises(ValueError, match="Speech to text is not enabled"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
def test_transcript_asr_raises_error_when_workflow_missing(self, factory):
|
||||
def test_transcript_asr_raises_error_when_workflow_missing(self, factory: AudioServiceTestDataFactory):
|
||||
"""Test that ASR raises error when workflow is missing in WORKFLOW mode."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock(
|
||||
@ -304,7 +309,7 @@ class TestAudioServiceASR:
|
||||
with pytest.raises(ValueError, match="Speech to text is not enabled"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory):
|
||||
def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory: AudioServiceTestDataFactory):
|
||||
"""Test that ASR raises error when no file is uploaded."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
@ -317,7 +322,7 @@ class TestAudioServiceASR:
|
||||
with pytest.raises(NoAudioUploadedServiceError):
|
||||
AudioService.transcript_asr(app_model=app, file=None)
|
||||
|
||||
def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory):
|
||||
def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory: AudioServiceTestDataFactory):
|
||||
"""Test that ASR raises error for unsupported audio file types."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
@ -331,7 +336,7 @@ class TestAudioServiceASR:
|
||||
with pytest.raises(UnsupportedAudioTypeServiceError):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
def test_transcript_asr_raises_error_for_large_file(self, factory):
|
||||
def test_transcript_asr_raises_error_for_large_file(self, factory: AudioServiceTestDataFactory):
|
||||
"""Test that ASR raises error when file exceeds size limit (30MB)."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
@ -348,7 +353,9 @@ class TestAudioServiceASR:
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
|
||||
def test_transcript_asr_raises_error_when_no_model_instance(
|
||||
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test that ASR raises error when no model instance is available."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True})
|
||||
@ -371,7 +378,7 @@ class TestAudioServiceTTS:
|
||||
"""Test text-to-speech (TTS) operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
|
||||
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory):
|
||||
"""Test successful TTS with text input."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
@ -405,7 +412,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
|
||||
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory: AudioServiceTestDataFactory):
|
||||
"""Test TTS uses default voice when none specified."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
@ -435,7 +442,9 @@ class TestAudioServiceTTS:
|
||||
assert call_args.kwargs["voice"] == "default-voice"
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
|
||||
def test_transcript_tts_gets_first_available_voice_when_none_configured(
|
||||
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test TTS gets first available voice when none is configured."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
@ -467,7 +476,7 @@ class TestAudioServiceTTS:
|
||||
@patch("services.audio_service.WorkflowService", autospec=True)
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_workflow_mode_with_draft(
|
||||
self, mock_model_manager_class, mock_workflow_service_class, factory
|
||||
self, mock_model_manager_class, mock_workflow_service_class, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test TTS in WORKFLOW mode with draft workflow."""
|
||||
# Arrange
|
||||
@ -499,7 +508,7 @@ class TestAudioServiceTTS:
|
||||
assert result == b"draft audio"
|
||||
mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app)
|
||||
|
||||
def test_transcript_tts_raises_error_when_text_missing(self, factory):
|
||||
def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory):
|
||||
"""Test that TTS raises error when text is missing."""
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
@ -509,7 +518,9 @@ class TestAudioServiceTTS:
|
||||
AudioService.transcript_tts(app_model=app, text=None)
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
|
||||
def test_transcript_tts_raises_error_when_no_voices_available(
|
||||
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test that TTS raises error when no voices are available."""
|
||||
# Arrange
|
||||
app_model_config = factory.create_app_model_config_mock(
|
||||
@ -535,7 +546,7 @@ class TestAudioServiceTTSVoices:
|
||||
"""Test TTS voice listing operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
|
||||
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory):
|
||||
"""Test successful retrieval of TTS voices."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -560,7 +571,9 @@ class TestAudioServiceTTSVoices:
|
||||
mock_model_instance.get_tts_voices.assert_called_once_with(language)
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
|
||||
def test_transcript_tts_voices_raises_error_when_no_model_instance(
|
||||
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test that TTS voices raises error when no model instance is available."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
@ -575,7 +588,9 @@ class TestAudioServiceTTSVoices:
|
||||
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
|
||||
|
||||
@patch("services.audio_service.ModelManager.for_tenant", autospec=True)
|
||||
def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
|
||||
def test_transcript_tts_voices_propagates_exceptions(
|
||||
self, mock_model_manager_class, factory: AudioServiceTestDataFactory
|
||||
):
|
||||
"""Test that TTS voices propagates exceptions from model instance."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
@ -268,8 +268,8 @@ class TestWebhookServiceUnit:
|
||||
}
|
||||
|
||||
# Mock file reads
|
||||
files["file1"].read.return_value = b"content1"
|
||||
files["file2"].read.return_value = b"content2"
|
||||
files["file1"].stream.read.return_value = b"content1"
|
||||
files["file2"].stream.read.return_value = b"content2"
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
@ -304,8 +304,8 @@ class TestWebhookServiceUnit:
|
||||
"bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
|
||||
}
|
||||
|
||||
files["good_file"].read.return_value = b"content"
|
||||
files["bad_file"].read.side_effect = Exception("Read error")
|
||||
files["good_file"].stream.read.return_value = b"content"
|
||||
files["bad_file"].stream.read.side_effect = Exception("Read error")
|
||||
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user