From ac222a4dd4f030e06a0e0b47daa7c11d0514f0d1 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 18:03:07 +0900 Subject: [PATCH] refactor: port api/controllers/console/app/audio.py api/controllers/console/app/message.py api/controllers/console/auth/data_source_oauth.py api/controllers/console/auth/forgot_password.py api/controllers/console/workspace/endpoint.py (#30680) --- api/controllers/console/app/audio.py | 16 ++--- api/controllers/console/app/message.py | 31 +++++---- .../console/auth/data_source_oauth.py | 33 +++++++-- .../console/auth/forgot_password.py | 50 ++++++++------ api/controllers/console/workspace/endpoint.py | 69 ++++++++++++++----- .../clickzetta_volume_storage.py | 3 +- 6 files changed, 135 insertions(+), 67 deletions(-) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index d344ede466..941db325bf 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -33,7 +34,6 @@ from services.errors.audio import ( ) logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class TextToSpeechPayload(BaseModel): @@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel): language: str = Field(..., description="Language code") -console_ns.schema_model( - TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechVoiceQuery.__name__, - TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class AudioTranscriptResponse(BaseModel): + text: str = Field(description="Transcribed text from audio") + + +register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery) @console_ns.route("/apps//audio-to-text") @@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource): @console_ns.response( 200, "Audio transcription successful", - console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + console_ns.models[AudioTranscriptResponse.__name__], ) @console_ns.response(400, "Bad request - No audio uploaded or unsupported type") @console_ns.response(413, "Audio file too large") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 12ada8b798..0be3e0ec49 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, select from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( CompletionRequestError, @@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft from services.message_service import MessageService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ChatMessagesQuery(BaseModel): @@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel): raise ValueError("has_comment must be a boolean value") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class AnnotationCountResponse(BaseModel): + count: int = Field(description="Number of annotations") -reg(ChatMessagesQuery) -reg(MessageFeedbackPayload) -reg(FeedbackExportQuery) +class SuggestedQuestionsResponse(BaseModel): + data: list[str] = Field(description="Suggested question") + + +register_schema_models( + console_ns, + ChatMessagesQuery, + MessageFeedbackPayload, + FeedbackExportQuery, + AnnotationCountResponse, + SuggestedQuestionsResponse, +) # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -231,7 +240,7 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): - args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ChatMessagesQuery.model_validate(request.args.to_dict()) conversation = ( db.session.query(Conversation) @@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource): @console_ns.response( 200, "Annotation count retrieved successfully", - console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + console_ns.models[AnnotationCountResponse.__name__], ) @get_app_model @setup_required @@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource): @console_ns.response( 200, "Suggested questions retrieved successfully", - console_ns.model( - "SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))} - ), + console_ns.models[SuggestedQuestionsResponse.__name__], ) @console_ns.response(404, "Message or conversation not found") @setup_required @@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource): @login_required @account_initialization_required def get(self, app_model): - args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FeedbackExportQuery.model_validate(request.args.to_dict()) # Import the service function from services.feedback_service import FeedbackService diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0dd7d33ae9..3a3278ec9d 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,9 +2,11 @@ import logging import httpx from flask import current_app, redirect, request -from flask_restx import Resource, fields +from flask_restx import Resource +from pydantic import BaseModel, Field from configs import dify_config +from controllers.common.schema import register_schema_models from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required, logger = logging.getLogger(__name__) +class OAuthDataSourceResponse(BaseModel): + data: str = Field(description="Authorization URL or 'internal' for internal setup") + + +class OAuthDataSourceBindingResponse(BaseModel): + result: str = Field(description="Operation result") + + +class OAuthDataSourceSyncResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + OAuthDataSourceResponse, + OAuthDataSourceBindingResponse, + OAuthDataSourceSyncResponse, +) + + def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( @@ -34,10 +56,7 @@ class OAuthDataSource(Resource): @console_ns.response( 200, "Authorization URL or internal setup success", - console_ns.model( - "OAuthDataSourceResponse", - {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, - ), + console_ns.models[OAuthDataSourceResponse.__name__], ) @console_ns.response(400, "Invalid provider") @console_ns.response(403, "Admin privileges required") @@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource): @console_ns.response( 200, "Data source binding success", - console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceBindingResponse.__name__], ) @console_ns.response(400, "Invalid provider or code") def get(self, provider: str): @@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource): @console_ns.response( 200, "Data source sync success", - console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceSyncResponse.__name__], ) @console_ns.response(400, "Invalid provider or sync failed") @setup_required diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 394f205d93..1ed931b0d7 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,10 +2,11 @@ import base64 import secrets from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailCodeError, @@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel): return valid_password(value) -for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class ForgotPasswordEmailResponse(BaseModel): + result: str = Field(description="Operation result") + data: str | None = Field(default=None, description="Reset token") + code: str | None = Field(default=None, description="Error code if account not found") + + +class ForgotPasswordCheckResponse(BaseModel): + is_valid: bool = Field(description="Whether code is valid") + email: EmailStr = Field(description="Email address") + token: str = Field(description="New reset token") + + +class ForgotPasswordResetResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + ForgotPasswordSendPayload, + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordEmailResponse, + ForgotPasswordCheckResponse, + ForgotPasswordResetResponse, +) @console_ns.route("/forgot-password") @@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): @console_ns.response( 200, "Email sent successfully", - console_ns.model( - "ForgotPasswordEmailResponse", - { - "result": fields.String(description="Operation result"), - "data": fields.String(description="Reset token"), - "code": fields.String(description="Error code if account not found"), - }, - ), + console_ns.models[ForgotPasswordEmailResponse.__name__], ) @console_ns.response(400, "Invalid email or rate limit exceeded") @setup_required @@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource): @console_ns.response( 200, "Code verified successfully", - console_ns.model( - "ForgotPasswordCheckResponse", - { - "is_valid": fields.Boolean(description="Whether code is valid"), - "email": fields.String(description="Email address"), - "token": fields.String(description="New reset token"), - }, - ), + console_ns.models[ForgotPasswordCheckResponse.__name__], ) @console_ns.response(400, "Invalid code or token") @setup_required @@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource): @console_ns.response( 200, "Password reset successfully", - console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[ForgotPasswordResetResponse.__name__], ) @console_ns.response(400, "Invalid token or password mismatch") @setup_required diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index bfd9fc6c29..1897cbdca7 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,9 +1,10 @@ from typing import Any from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder @@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery): plugin_id: str +class EndpointCreateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class PluginEndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class EndpointDeleteResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointUpdateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointEnableResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointDisableResponse(BaseModel): + success: bool = Field(description="Operation success") + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -reg(EndpointCreatePayload) -reg(EndpointIdPayload) -reg(EndpointUpdatePayload) -reg(EndpointListQuery) -reg(EndpointListForPluginQuery) +register_schema_models( + console_ns, + EndpointCreatePayload, + EndpointIdPayload, + EndpointUpdatePayload, + EndpointListQuery, + EndpointListForPluginQuery, + EndpointCreateResponse, + EndpointListResponse, + PluginEndpointListResponse, + EndpointDeleteResponse, + EndpointUpdateResponse, + EndpointEnableResponse, + EndpointDisableResponse, +) @console_ns.route("/workspaces/current/endpoints/create") @@ -57,7 +96,7 @@ class EndpointCreateApi(Resource): @console_ns.response( 200, "Endpoint created successfully", - console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointCreateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -91,9 +130,7 @@ class EndpointListApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[EndpointListResponse.__name__], ) @setup_required @login_required @@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[PluginEndpointListResponse.__name__], ) @setup_required @login_required @@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource): @console_ns.response( 200, "Endpoint deleted successfully", - console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDeleteResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource): @console_ns.response( 200, "Endpoint updated successfully", - console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointUpdateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -221,7 +256,7 @@ class EndpointEnableApi(Resource): @console_ns.response( 200, "Endpoint enabled successfully", - console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointEnableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -248,7 +283,7 @@ class EndpointDisableApi(Resource): @console_ns.response( 200, "Endpoint disabled successfully", - console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDisableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index c1608f58a5..18eed4e481 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage): """ content = self.load_once(filename) - with Path(target_filepath).open("wb") as f: - f.write(content) + Path(target_filepath).write_bytes(content) logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)