diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index da9282cd0c..7aa1e6dbd8 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,7 +3,8 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized @@ -18,6 +19,30 @@ from extensions.ext_database import db from libs.token import extract_access_token from models.model import App, InstalledApp, RecommendedApp +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class InsertExploreAppPayload(BaseModel): + app_id: str = Field(...) + desc: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + language: str = Field(...) + category: str = Field(...) + position: int = Field(...) + + @field_validator("language") + @classmethod + def validate_language(cls, value: str) -> str: + return supported_language(value) + + +console_ns.schema_model( + InsertExploreAppPayload.__name__, + InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + def admin_required(view: Callable[P, R]): @wraps(view) @@ -40,59 +65,34 @@ def admin_required(view: Callable[P, R]): class InsertExploreAppListApi(Resource): @console_ns.doc("insert_explore_app") @console_ns.doc(description="Insert or update an app in the explore list") - @console_ns.expect( - console_ns.model( - "InsertExploreAppRequest", - { - "app_id": fields.String(required=True, description="Application ID"), - "desc": fields.String(description="App description"), - "copyright": fields.String(description="Copyright information"), - "privacy_policy": fields.String(description="Privacy policy"), - "custom_disclaimer": fields.String(description="Custom disclaimer"), - "language": fields.String(required=True, description="Language code"), - "category": fields.String(required=True, description="App category"), - "position": fields.Integer(required=True, description="Display position"), - }, - ) - ) + @console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__]) @console_ns.response(200, "App updated successfully") @console_ns.response(201, "App inserted successfully") @console_ns.response(404, "App not found") @only_edition_cloud @admin_required def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("app_id", type=str, required=True, nullable=False, location="json") - .add_argument("desc", type=str, location="json") - .add_argument("copyright", type=str, location="json") - .add_argument("privacy_policy", type=str, location="json") - .add_argument("custom_disclaimer", type=str, location="json") - .add_argument("language", type=supported_language, required=True, nullable=False, location="json") - .add_argument("category", type=str, required=True, nullable=False, location="json") - .add_argument("position", type=int, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + payload = InsertExploreAppPayload.model_validate(console_ns.payload) - app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() + app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none() if not app: - raise NotFound(f"App '{args['app_id']}' is not found") + raise NotFound(f"App '{payload.app_id}' is not found") site = app.site if not site: - desc = args["desc"] or "" - copy_right = args["copyright"] or "" - privacy_policy = args["privacy_policy"] or "" - custom_disclaimer = args["custom_disclaimer"] or "" + desc = payload.desc or "" + copy_right = payload.copyright or "" + privacy_policy = payload.privacy_policy or "" + custom_disclaimer = payload.custom_disclaimer or "" else: - desc = site.description or args["desc"] or "" - copy_right = site.copyright or args["copyright"] or "" - privacy_policy = site.privacy_policy or args["privacy_policy"] or "" - custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" + desc = site.description or payload.desc or "" + copy_right = site.copyright or payload.copyright or "" + privacy_policy = site.privacy_policy or payload.privacy_policy or "" + custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or "" with Session(db.engine) as session: recommended_app = session.execute( - select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) + select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id) ).scalar_one_or_none() if not recommended_app: @@ -102,9 +102,9 @@ class InsertExploreAppListApi(Resource): copyright=copy_right, privacy_policy=privacy_policy, custom_disclaimer=custom_disclaimer, - language=args["language"], - category=args["category"], - position=args["position"], + language=payload.language, + category=payload.category, + position=payload.position, ) db.session.add(recommended_app) @@ -118,9 +118,9 @@ class InsertExploreAppListApi(Resource): recommended_app.copyright = copy_right recommended_app.privacy_policy = privacy_policy recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args["language"] - recommended_app.category = args["category"] - recommended_app.position = args["position"] + recommended_app.language = payload.language + recommended_app.category = payload.category + recommended_app.position = payload.position app.is_public = True diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index 7e31d0a844..cfdb9cf417 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,4 +1,6 @@ -from flask_restx import Resource, fields, reqparse +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.app.wraps import get_app_model @@ -8,10 +10,21 @@ from libs.login import login_required from models.model import AppMode from services.agent_service import AgentService -parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID") - .add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID") +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AgentLogQuery(BaseModel): + message_id: str = Field(..., description="Message UUID") + conversation_id: str = Field(..., description="Conversation UUID") + + @field_validator("message_id", "conversation_id") + @classmethod + def validate_uuid(cls, value: str) -> str: + return uuid_value(value) + + +console_ns.schema_model( + AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @@ -20,7 +33,7 @@ class AgentLogApi(Resource): @console_ns.doc("get_agent_logs") @console_ns.doc(description="Get agent execution logs for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[AgentLogQuery.__name__]) @console_ns.response( 200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")) ) @@ -31,6 +44,6 @@ class AgentLogApi(Resource): @get_app_model(mode=[AppMode.AGENT_CHAT]) def get(self, app_model): """Get agent logs""" - args = parser.parse_args() + args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) + return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index edf0cc2cec..3b6fb58931 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,7 +1,8 @@ -from typing import Literal +from typing import Any, Literal from flask import request -from flask_restx import Resource, fields, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field, field_validator from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.console import console_ns @@ -21,22 +22,79 @@ from libs.helper import uuid_value from libs.login import login_required from services.annotation_service import AppAnnotationService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AnnotationReplyPayload(BaseModel): + score_threshold: float = Field(..., description="Score threshold for annotation matching") + embedding_provider_name: str = Field(..., description="Embedding provider name") + embedding_model_name: str = Field(..., description="Embedding model name") + + +class AnnotationSettingUpdatePayload(BaseModel): + score_threshold: float = Field(..., description="Score threshold") + + +class AnnotationListQuery(BaseModel): + page: int = Field(default=1, ge=1, description="Page number") + limit: int = Field(default=20, ge=1, description="Page size") + keyword: str = Field(default="", description="Search keyword") + + +class CreateAnnotationPayload(BaseModel): + message_id: str | None = Field(default=None, description="Message ID") + question: str | None = Field(default=None, description="Question text") + answer: str | None = Field(default=None, description="Answer text") + content: str | None = Field(default=None, description="Content text") + annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data") + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class UpdateAnnotationPayload(BaseModel): + question: str | None = None + answer: str | None = None + content: str | None = None + annotation_reply: dict[str, Any] | None = None + + +class AnnotationReplyStatusQuery(BaseModel): + action: Literal["enable", "disable"] + + +class AnnotationFilePayload(BaseModel): + message_id: str = Field(..., description="Message ID") + + @field_validator("message_id") + @classmethod + def validate_message_id(cls, value: str) -> str: + return uuid_value(value) + + +def reg(model: type[BaseModel]) -> None: + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(AnnotationReplyPayload) +reg(AnnotationSettingUpdatePayload) +reg(AnnotationListQuery) +reg(CreateAnnotationPayload) +reg(UpdateAnnotationPayload) +reg(AnnotationReplyStatusQuery) +reg(AnnotationFilePayload) + @console_ns.route("/apps//annotation-reply/") class AnnotationReplyActionApi(Resource): @console_ns.doc("annotation_reply_action") @console_ns.doc(description="Enable or disable annotation reply for an app") @console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) - @console_ns.expect( - console_ns.model( - "AnnotationReplyActionRequest", - { - "score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"), - "embedding_provider_name": fields.String(required=True, description="Embedding provider name"), - "embedding_model_name": fields.String(required=True, description="Embedding model name"), - }, - ) - ) + @console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__]) @console_ns.response(200, "Action completed successfully") @console_ns.response(403, "Insufficient permissions") @setup_required @@ -46,15 +104,9 @@ class AnnotationReplyActionApi(Resource): @edit_permission_required def post(self, app_id, action: Literal["enable", "disable"]): app_id = str(app_id) - parser = ( - reqparse.RequestParser() - .add_argument("score_threshold", required=True, type=float, location="json") - .add_argument("embedding_provider_name", required=True, type=str, location="json") - .add_argument("embedding_model_name", required=True, type=str, location="json") - ) - args = parser.parse_args() + args = AnnotationReplyPayload.model_validate(console_ns.payload) if action == "enable": - result = AppAnnotationService.enable_app_annotation(args, app_id) + result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -82,16 +134,7 @@ class AppAnnotationSettingUpdateApi(Resource): @console_ns.doc("update_annotation_setting") @console_ns.doc(description="Update annotation settings for an app") @console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) - @console_ns.expect( - console_ns.model( - "AnnotationSettingUpdateRequest", - { - "score_threshold": fields.Float(required=True, description="Score threshold"), - "embedding_provider_name": fields.String(required=True, description="Embedding provider"), - "embedding_model_name": fields.String(required=True, description="Embedding model"), - }, - ) - ) + @console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__]) @console_ns.response(200, "Settings updated successfully") @console_ns.response(403, "Insufficient permissions") @setup_required @@ -102,10 +145,9 @@ class AppAnnotationSettingUpdateApi(Resource): app_id = str(app_id) annotation_setting_id = str(annotation_setting_id) - parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json") - args = parser.parse_args() + args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) + result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump()) return result, 200 @@ -142,12 +184,7 @@ class AnnotationApi(Resource): @console_ns.doc("list_annotations") @console_ns.doc(description="Get annotations for an app with pagination") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser() - .add_argument("page", type=int, location="args", default=1, help="Page number") - .add_argument("limit", type=int, location="args", default=20, help="Page size") - .add_argument("keyword", type=str, location="args", default="", help="Search keyword") - ) + @console_ns.expect(console_ns.models[AnnotationListQuery.__name__]) @console_ns.response(200, "Annotations retrieved successfully") @console_ns.response(403, "Insufficient permissions") @setup_required @@ -155,9 +192,10 @@ class AnnotationApi(Resource): @account_initialization_required @edit_permission_required def get(self, app_id): - page = request.args.get("page", default=1, type=int) - limit = request.args.get("limit", default=20, type=int) - keyword = request.args.get("keyword", default="", type=str) + args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + page = args.page + limit = args.limit + keyword = args.keyword app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) @@ -173,18 +211,7 @@ class AnnotationApi(Resource): @console_ns.doc("create_annotation") @console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "CreateAnnotationRequest", - { - "message_id": fields.String(description="Message ID (optional)"), - "question": fields.String(description="Question text (required when message_id not provided)"), - "answer": fields.String(description="Answer text (use 'answer' or 'content')"), - "content": fields.String(description="Content text (use 'answer' or 'content')"), - "annotation_reply": fields.Raw(description="Annotation reply data"), - }, - ) - ) + @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__]) @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -195,16 +222,9 @@ class AnnotationApi(Resource): @edit_permission_required def post(self, app_id): app_id = str(app_id) - parser = ( - reqparse.RequestParser() - .add_argument("message_id", required=False, type=uuid_value, location="json") - .add_argument("question", required=False, type=str, location="json") - .add_argument("answer", required=False, type=str, location="json") - .add_argument("content", required=False, type=str, location="json") - .add_argument("annotation_reply", required=False, type=dict, location="json") - ) - args = parser.parse_args() - annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) + args = CreateAnnotationPayload.model_validate(console_ns.payload) + data = args.model_dump(exclude_none=True) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) return annotation @setup_required @@ -256,13 +276,6 @@ class AnnotationExportApi(Resource): return response, 200 -parser = ( - reqparse.RequestParser() - .add_argument("question", required=True, type=str, location="json") - .add_argument("answer", required=True, type=str, location="json") -) - - @console_ns.route("/apps//annotations/") class AnnotationUpdateDeleteApi(Resource): @console_ns.doc("update_delete_annotation") @@ -271,7 +284,7 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) @console_ns.response(204, "Annotation deleted successfully") @console_ns.response(403, "Insufficient permissions") - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -281,8 +294,10 @@ class AnnotationUpdateDeleteApi(Resource): def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) - args = parser.parse_args() - annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) + args = UpdateAnnotationPayload.model_validate(console_ns.payload) + annotation = AppAnnotationService.update_app_annotation_directly( + args.model_dump(exclude_none=True), app_id, annotation_id + ) return annotation @setup_required diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 1b02edd489..22e2aeb720 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,4 +1,5 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from controllers.console.app.wraps import get_app_model @@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model( "AppImportCheckDependencies", app_import_check_dependencies_fields_copy ) -parser = ( - reqparse.RequestParser() - .add_argument("mode", type=str, required=True, location="json") - .add_argument("yaml_content", type=str, location="json") - .add_argument("yaml_url", type=str, location="json") - .add_argument("name", type=str, location="json") - .add_argument("description", type=str, location="json") - .add_argument("icon_type", type=str, location="json") - .add_argument("icon", type=str, location="json") - .add_argument("icon_background", type=str, location="json") - .add_argument("app_id", type=str, location="json") +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AppImportPayload(BaseModel): + mode: str = Field(..., description="Import mode") + yaml_content: str | None = None + yaml_url: str | None = None + name: str | None = None + description: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + app_id: str | None = None + + +console_ns.schema_model( + AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @console_ns.route("/apps/imports") class AppImportApi(Resource): - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[AppImportPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -61,7 +68,7 @@ class AppImportApi(Resource): def post(self): # Check user role first current_user, _ = current_account_with_tenant() - args = parser.parse_args() + args = AppImportPayload.model_validate(console_ns.payload) # Create service with session with Session(db.engine) as session: @@ -70,15 +77,15 @@ class AppImportApi(Resource): account = current_user result = import_service.import_app( account=account, - import_mode=args["mode"], - yaml_content=args.get("yaml_content"), - yaml_url=args.get("yaml_url"), - name=args.get("name"), - description=args.get("description"), - icon_type=args.get("icon_type"), - icon=args.get("icon"), - icon_background=args.get("icon_background"), - app_id=args.get("app_id"), + import_mode=args.mode, + yaml_content=args.yaml_content, + yaml_url=args.yaml_url, + name=args.name, + description=args.description, + icon_type=args.icon_type, + icon=args.icon, + icon_background=args.icon_background, + app_id=args.app_id, ) session.commit() if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 86446f1164..d344ede466 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services @@ -32,6 +33,27 @@ from services.errors.audio import ( ) logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class TextToSpeechPayload(BaseModel): + message_id: str | None = Field(default=None, description="Message ID") + text: str = Field(..., description="Text to convert") + voice: str | None = Field(default=None, description="Voice name") + streaming: bool | None = Field(default=None, description="Whether to stream audio") + + +class TextToSpeechVoiceQuery(BaseModel): + language: str = Field(..., description="Language code") + + +console_ns.schema_model( + TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + TextToSpeechVoiceQuery.__name__, + TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) @console_ns.route("/apps//audio-to-text") @@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource): @console_ns.doc("chat_message_text_to_speech") @console_ns.doc(description="Convert text to speech for chat messages") @console_ns.doc(params={"app_id": "App ID"}) - @console_ns.expect( - console_ns.model( - "TextToSpeechRequest", - { - "message_id": fields.String(description="Message ID"), - "text": fields.String(required=True, description="Text to convert to speech"), - "voice": fields.String(description="Voice to use for TTS"), - "streaming": fields.Boolean(description="Whether to stream the audio"), - }, - ) - ) + @console_ns.expect(console_ns.models[TextToSpeechPayload.__name__]) @console_ns.response(200, "Text to speech conversion successful") @console_ns.response(400, "Bad request - Invalid parameters") @get_app_model @@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource): @account_initialization_required def post(self, app_model: App): try: - parser = ( - reqparse.RequestParser() - .add_argument("message_id", type=str, location="json") - .add_argument("text", type=str, location="json") - .add_argument("voice", type=str, location="json") - .add_argument("streaming", type=bool, location="json") - ) - args = parser.parse_args() - - message_id = args.get("message_id", None) - text = args.get("text", None) - voice = args.get("voice", None) + payload = TextToSpeechPayload.model_validate(console_ns.payload) response = AudioService.transcript_tts( - app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True + app_model=app_model, + text=payload.text, + voice=payload.voice, + message_id=payload.message_id, + is_draft=True, ) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -159,9 +164,7 @@ class TextModesApi(Resource): @console_ns.doc("get_text_to_speech_voices") @console_ns.doc(description="Get available TTS voices for a specific language") @console_ns.doc(params={"app_id": "App ID"}) - @console_ns.expect( - console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code") - ) + @console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__]) @console_ns.response( 200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")) ) @@ -172,12 +175,11 @@ class TextModesApi(Resource): @account_initialization_required def get(self, app_model): try: - parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args") - args = parser.parse_args() + args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, - language=args["language"], + language=args.language, ) return response diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 58d1fb4a2d..dd982b6d7b 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,7 +1,8 @@ import json from enum import StrEnum -from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields from libs.login import current_account_with_tenant, login_required from models.model import AppMCPServer +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + # Register model for flask_restx to avoid dict type issues in Swagger app_server_model = console_ns.model("AppServer", app_server_fields) @@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum): INACTIVE = "inactive" +class MCPServerCreatePayload(BaseModel): + description: str | None = Field(default=None, description="Server description") + parameters: dict = Field(..., description="Server parameters configuration") + + +class MCPServerUpdatePayload(BaseModel): + id: str = Field(..., description="Server ID") + description: str | None = Field(default=None, description="Server description") + parameters: dict = Field(..., description="Server parameters configuration") + status: str | None = Field(default=None, description="Server status") + + +for model in (MCPServerCreatePayload, MCPServerUpdatePayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + @console_ns.route("/apps//server") class AppMCPServerController(Resource): @console_ns.doc("get_app_mcp_server") @@ -39,15 +58,7 @@ class AppMCPServerController(Resource): @console_ns.doc("create_app_mcp_server") @console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "MCPServerCreateRequest", - { - "description": fields.String(description="Server description"), - "parameters": fields.Raw(required=True, description="Server parameters configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__]) @console_ns.response(201, "MCP server configuration created successfully", app_server_model) @console_ns.response(403, "Insufficient permissions") @account_initialization_required @@ -58,21 +69,16 @@ class AppMCPServerController(Resource): @edit_permission_required def post(self, app_model): _, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("description", type=str, required=False, location="json") - .add_argument("parameters", type=dict, required=True, location="json") - ) - args = parser.parse_args() + payload = MCPServerCreatePayload.model_validate(console_ns.payload or {}) - description = args.get("description") + description = payload.description if not description: description = app_model.description or "" server = AppMCPServer( name=app_model.name, description=description, - parameters=json.dumps(args["parameters"], ensure_ascii=False), + parameters=json.dumps(payload.parameters, ensure_ascii=False), status=AppMCPServerStatus.ACTIVE, app_id=app_model.id, tenant_id=current_tenant_id, @@ -85,17 +91,7 @@ class AppMCPServerController(Resource): @console_ns.doc("update_app_mcp_server") @console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "MCPServerUpdateRequest", - { - "id": fields.String(required=True, description="Server ID"), - "description": fields.String(description="Server description"), - "parameters": fields.Raw(required=True, description="Server parameters configuration"), - "status": fields.String(description="Server status"), - }, - ) - ) + @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__]) @console_ns.response(200, "MCP server configuration updated successfully", app_server_model) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @@ -106,19 +102,12 @@ class AppMCPServerController(Resource): @marshal_with(app_server_model) @edit_permission_required def put(self, app_model): - parser = ( - reqparse.RequestParser() - .add_argument("id", type=str, required=True, location="json") - .add_argument("description", type=str, required=False, location="json") - .add_argument("parameters", type=dict, required=True, location="json") - .add_argument("status", type=str, required=False, location="json") - ) - args = parser.parse_args() - server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() + payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) + server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first() if not server: raise NotFound() - description = args.get("description") + description = payload.description if description is None: pass elif not description: @@ -126,11 +115,11 @@ class AppMCPServerController(Resource): else: server.description = description - server.parameters = json.dumps(args["parameters"], ensure_ascii=False) - if args["status"]: - if args["status"] not in [status.value for status in AppMCPServerStatus]: + server.parameters = json.dumps(payload.parameters, ensure_ascii=False) + if payload.status: + if payload.status not in [status.value for status in AppMCPServerStatus]: raise ValueError("Invalid status") - server.status = args["status"] + server.status = payload.status db.session.commit() return server diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 19c1a11258..cbcf513162 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,8 @@ -from flask_restx import Resource, fields, reqparse +from typing import Any + +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest from controllers.console import console_ns @@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req from libs.login import login_required from services.ops_service import OpsService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class TraceProviderQuery(BaseModel): + tracing_provider: str = Field(..., description="Tracing provider name") + + +class TraceConfigPayload(BaseModel): + tracing_provider: str = Field(..., description="Tracing provider name") + tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data") + + +console_ns.schema_model( + TraceProviderQuery.__name__, + TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + @console_ns.route("/apps//trace-config") class TraceAppConfigApi(Resource): @@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource): @console_ns.doc("get_trace_app_config") @console_ns.doc(description="Get tracing configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser().add_argument( - "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" - ) - ) + @console_ns.expect(console_ns.models[TraceProviderQuery.__name__]) @console_ns.response( 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") ) @@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource): @login_required @account_initialization_required def get(self, app_id): - parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") - args = parser.parse_args() + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) + trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) if not trace_config: return {"has_not_configured": True} return trace_config @@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource): @console_ns.doc("create_trace_app_config") @console_ns.doc(description="Create a new tracing configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "TraceConfigCreateRequest", - { - "tracing_provider": fields.String(required=True, description="Tracing provider name"), - "tracing_config": fields.Raw(required=True, description="Tracing configuration data"), - }, - ) - ) + @console_ns.expect(console_ns.models[TraceConfigPayload.__name__]) @console_ns.response( 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") ) @@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def post(self, app_id): """Create a new trace app configuration""" - parser = ( - reqparse.RequestParser() - .add_argument("tracing_provider", type=str, required=True, location="json") - .add_argument("tracing_config", type=dict, required=True, location="json") - ) - args = parser.parse_args() + args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.create_tracing_app_config( - app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] + app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigIsExist() @@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource): @console_ns.doc("update_trace_app_config") @console_ns.doc(description="Update an existing tracing configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "TraceConfigUpdateRequest", - { - "tracing_provider": fields.String(required=True, description="Tracing provider name"), - "tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"), - }, - ) - ) + @console_ns.expect(console_ns.models[TraceConfigPayload.__name__]) @console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) @console_ns.response(400, "Invalid request parameters or configuration not found") @setup_required @@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def patch(self, app_id): """Update an existing trace app configuration""" - parser = ( - reqparse.RequestParser() - .add_argument("tracing_provider", type=str, required=True, location="json") - .add_argument("tracing_config", type=dict, required=True, location="json") - ) - args = parser.parse_args() + args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.update_tracing_app_config( - app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] + app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigNotExist() @@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource): @console_ns.doc("delete_trace_app_config") @console_ns.doc(description="Delete an existing tracing configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.parser().add_argument( - "tracing_provider", type=str, required=True, location="args", help="Tracing provider name" - ) - ) + @console_ns.expect(console_ns.models[TraceProviderQuery.__name__]) @console_ns.response(204, "Tracing configuration deleted successfully") @console_ns.response(400, "Invalid request parameters or configuration not found") @setup_required @@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource): @account_initialization_required def delete(self, app_id): """Delete an existing trace app configuration""" - parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args") - args = parser.parse_args() + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) + result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) if not result: raise TracingConfigNotExist() return {"result": "success"}, 204 diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index d46b8c5c9d..db218d8b81 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,4 +1,7 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from typing import Literal + +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import NotFound from constants.languages import supported_language @@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import Site +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class AppSiteUpdatePayload(BaseModel): + title: str | None = Field(default=None) + icon_type: str | None = Field(default=None) + icon: str | None = Field(default=None) + icon_background: str | None = Field(default=None) + description: str | None = Field(default=None) + default_language: str | None = Field(default=None) + chat_color_theme: str | None = Field(default=None) + chat_color_theme_inverted: bool | None = Field(default=None) + customize_domain: str | None = Field(default=None) + copyright: str | None = Field(default=None) + privacy_policy: str | None = Field(default=None) + custom_disclaimer: str | None = Field(default=None) + customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None) + prompt_public: bool | None = Field(default=None) + show_workflow_steps: bool | None = Field(default=None) + use_icon_as_answer_icon: bool | None = Field(default=None) + + @field_validator("default_language") + @classmethod + def validate_language(cls, value: str | None) -> str | None: + if value is None: + return value + return supported_language(value) + + +console_ns.schema_model( + AppSiteUpdatePayload.__name__, + AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + # Register model for flask_restx to avoid dict type issues in Swagger app_site_model = console_ns.model("AppSite", app_site_fields) -def parse_app_site_args(): - parser = ( - reqparse.RequestParser() - .add_argument("title", type=str, required=False, location="json") - .add_argument("icon_type", type=str, required=False, location="json") - .add_argument("icon", type=str, required=False, location="json") - .add_argument("icon_background", type=str, required=False, location="json") - .add_argument("description", type=str, required=False, location="json") - .add_argument("default_language", type=supported_language, required=False, location="json") - .add_argument("chat_color_theme", type=str, required=False, location="json") - .add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") - .add_argument("customize_domain", type=str, required=False, location="json") - .add_argument("copyright", type=str, required=False, location="json") - .add_argument("privacy_policy", type=str, required=False, location="json") - .add_argument("custom_disclaimer", type=str, required=False, location="json") - .add_argument( - "customize_token_strategy", - type=str, - choices=["must", "allow", "not_allow"], - required=False, - location="json", - ) - .add_argument("prompt_public", type=bool, required=False, location="json") - .add_argument("show_workflow_steps", type=bool, required=False, location="json") - .add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json") - ) - return parser.parse_args() - - @console_ns.route("/apps//site") class AppSite(Resource): @console_ns.doc("update_app_site") @console_ns.doc(description="Update application site configuration") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "AppSiteRequest", - { - "title": fields.String(description="Site title"), - "icon_type": fields.String(description="Icon type"), - "icon": fields.String(description="Icon"), - "icon_background": fields.String(description="Icon background color"), - "description": fields.String(description="Site description"), - "default_language": fields.String(description="Default language"), - "chat_color_theme": fields.String(description="Chat color theme"), - "chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"), - "customize_domain": fields.String(description="Custom domain"), - "copyright": fields.String(description="Copyright text"), - "privacy_policy": fields.String(description="Privacy policy"), - "custom_disclaimer": fields.String(description="Custom disclaimer"), - "customize_token_strategy": fields.String( - enum=["must", "allow", "not_allow"], description="Token strategy" - ), - "prompt_public": fields.Boolean(description="Make prompt public"), - "show_workflow_steps": fields.Boolean(description="Show workflow steps"), - "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"), - }, - ) - ) + @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__]) @console_ns.response(200, "Site configuration updated successfully", app_site_model) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "App not found") @@ -89,7 +73,7 @@ class AppSite(Resource): @get_app_model @marshal_with(app_site_model) def post(self, app_model): - args = parse_app_site_args() + args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: @@ -113,7 +97,7 @@ class AppSite(Resource): "show_workflow_steps", "use_icon_as_answer_icon", ]: - value = args.get(attr_name) + value = getattr(args, attr_name) if value is not None: setattr(site, attr_name, value) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 41ae8727de..3382b65acc 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -1,10 +1,11 @@ import logging from collections.abc import Callable from functools import wraps -from typing import NoReturn, ParamSpec, TypeVar +from typing import Any, NoReturn, ParamSpec, TypeVar -from flask import Response -from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask import Response, request +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field from sqlalchemy.orm import Session from controllers.console import console_ns @@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowDraftVariableListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=100_000, description="Page number") + limit: int = Field(default=20, ge=1, le=100, description="Items per page") + + +class WorkflowDraftVariableUpdatePayload(BaseModel): + name: str | None = Field(default=None, description="Variable name") + value: Any | None = Field(default=None, description="Variable value") + + +console_ns.schema_model( + WorkflowDraftVariableListQuery.__name__, + WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + WorkflowDraftVariableUpdatePayload.__name__, + WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) def _convert_values_to_json_serializable_object(value: Segment): @@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable): return _convert_values_to_json_serializable_object(value) -def _create_pagination_parser(): - parser = ( - reqparse.RequestParser() - .add_argument( - "page", - type=inputs.int_range(1, 100_000), - required=False, - default=1, - location="args", - help="the page of data requested", - ) - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") - ) - return parser - - def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: value_type = workflow_draft_var.value_type return value_type.exposed_type().value @@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]): @console_ns.route("/apps//workflows/draft/variables") class WorkflowVariableCollectionApi(Resource): - @console_ns.expect(_create_pagination_parser()) + @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__]) @console_ns.doc("get_workflow_variables") @console_ns.doc(description="Get draft workflow variables") @console_ns.doc(params={"app_id": "Application ID"}) @@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource): """ Get draft workflow """ - parser = _create_pagination_parser() - args = parser.parse_args() + args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -323,15 +328,7 @@ class VariableApi(Resource): @console_ns.doc("update_variable") @console_ns.doc(description="Update a workflow variable") - @console_ns.expect( - console_ns.model( - "UpdateVariableRequest", - { - "name": fields.String(description="Variable name"), - "value": fields.Raw(description="Variable value"), - }, - ) - ) + @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__]) @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model) @console_ns.response(404, "Variable not found") @_api_prerequisite @@ -358,16 +355,10 @@ class VariableApi(Resource): # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4" # } - parser = ( - reqparse.RequestParser() - .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json") - .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json") - ) - draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - args = parser.parse_args(strict=True) + args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {}) variable = draft_var_srv.get_variable(variable_id=variable_id) if variable is None: @@ -375,8 +366,8 @@ class VariableApi(Resource): if variable.app_id != app_model.id: raise NotFoundError(description=f"variable not found, id={variable_id}") - new_name = args.get(self._PATCH_NAME_FIELD, None) - raw_value = args.get(self._PATCH_VALUE_FIELD, None) + new_name = args_model.name + raw_value = args_model.value if new_name is None and raw_value is None: return variable diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index a11b741040..6834656a7f 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,28 +1,53 @@ from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from constants.languages import supported_language from controllers.console import console_ns from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from libs.helper import StrLen, email, extract_remote_ip, timezone +from libs.helper import EmailStr, extract_remote_ip, timezone from models import AccountStatus from services.account_service import AccountService, RegisterService -active_check_parser = ( - reqparse.RequestParser() - .add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID") - .add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address") - .add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token") -) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ActivateCheckQuery(BaseModel): + workspace_id: str | None = Field(default=None) + email: EmailStr | None = Field(default=None) + token: str + + +class ActivatePayload(BaseModel): + workspace_id: str | None = Field(default=None) + email: EmailStr | None = Field(default=None) + token: str + name: str = Field(..., max_length=30) + interface_language: str = Field(...) + timezone: str = Field(...) + + @field_validator("interface_language") + @classmethod + def validate_lang(cls, value: str) -> str: + return supported_language(value) + + @field_validator("timezone") + @classmethod + def validate_tz(cls, value: str) -> str: + return timezone(value) + + +for model in (ActivateCheckQuery, ActivatePayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) @console_ns.route("/activate/check") class ActivateCheckApi(Resource): @console_ns.doc("check_activation_token") @console_ns.doc(description="Check if activation token is valid") - @console_ns.expect(active_check_parser) + @console_ns.expect(console_ns.models[ActivateCheckQuery.__name__]) @console_ns.response( 200, "Success", @@ -35,11 +60,11 @@ class ActivateCheckApi(Resource): ), ) def get(self): - args = active_check_parser.parse_args() + args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - workspaceId = args["workspace_id"] - reg_email = args["email"] - token = args["token"] + workspaceId = args.workspace_id + reg_email = args.email + token = args.token invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) if invitation: @@ -56,22 +81,11 @@ class ActivateCheckApi(Resource): return {"is_valid": False} -active_parser = ( - reqparse.RequestParser() - .add_argument("workspace_id", type=str, required=False, nullable=True, location="json") - .add_argument("email", type=email, required=False, nullable=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") - .add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json") - .add_argument("timezone", type=timezone, required=True, nullable=False, location="json") -) - - @console_ns.route("/activate") class ActivateApi(Resource): @console_ns.doc("activate_account") @console_ns.doc(description="Activate account with invitation token") - @console_ns.expect(active_parser) + @console_ns.expect(console_ns.models[ActivatePayload.__name__]) @console_ns.response( 200, "Account activated successfully", @@ -85,19 +99,19 @@ class ActivateApi(Resource): ) @console_ns.response(400, "Already activated or invalid token") def post(self): - args = active_parser.parse_args() + args = ActivatePayload.model_validate(console_ns.payload) - invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) + invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) + RegisterService.revoke_token(args.workspace_id, args.email, args.token) account = invitation["account"] - account.name = args["name"] + account.name = args.name - account.interface_language = args["interface_language"] - account.timezone = args["timezone"] + account.interface_language = args.interface_language + account.timezone = args.timezone account.interface_theme = "light" account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 9d7fcef183..905d0daef0 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,12 +1,26 @@ -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field -from controllers.console import console_ns -from controllers.console.auth.error import ApiKeyAuthFailedError -from controllers.console.wraps import is_admin_or_owner_required from libs.login import current_account_with_tenant, login_required from services.auth.api_key_auth_service import ApiKeyAuthService -from ..wraps import account_initialization_required, setup_required +from .. import console_ns +from ..auth.error import ApiKeyAuthFailedError +from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ApiKeyAuthBindingPayload(BaseModel): + category: str = Field(...) + provider: str = Field(...) + credentials: dict = Field(...) + + +console_ns.schema_model( + ApiKeyAuthBindingPayload.__name__, + ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) @console_ns.route("/api-key-auth/data-source") @@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource): @login_required @account_initialization_required @is_admin_or_owner_required + @console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__]) def post(self): # The role of the current user in the table must be admin or owner _, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("category", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - ) - args = parser.parse_args() - ApiKeyAuthService.validate_api_key_auth_args(args) + payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload) + data = payload.model_dump() + ApiKeyAuthService.validate_api_key_auth_args(data) try: - ApiKeyAuthService.create_provider_auth(current_tenant_id, args) + ApiKeyAuthService.create_provider_auth(current_tenant_id, data) except Exception as e: raise ApiKeyAuthFailedError(str(e)) return {"result": "success"}, 200 diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index cd547caf20..0dd7d33ae9 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -5,12 +5,11 @@ from flask import current_app, redirect, request from flask_restx import Resource, fields from configs import dify_config -from controllers.console import console_ns -from controllers.console.wraps import is_admin_or_owner_required from libs.login import login_required from libs.oauth_data_source import NotionOAuth -from ..wraps import account_initialization_required, setup_required +from .. import console_ns +from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required logger = logging.getLogger(__name__) diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index fe2bb54e0b..fa082c735d 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,5 +1,6 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,16 +15,45 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError -from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required from extensions.ext_database import db -from libs.helper import email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models import Account from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import AccountNotFoundError, AccountRegisterError +from ..error import AccountInFreezeError, EmailSendIpLimitError +from ..wraps import email_password_login_enabled, email_register_enabled, setup_required + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class EmailRegisterSendPayload(BaseModel): + email: EmailStr = Field(..., description="Email address") + language: str | None = Field(default=None, description="Language code") + + +class EmailRegisterValidityPayload(BaseModel): + email: EmailStr = Field(...) + code: str = Field(...) + token: str = Field(...) + + +class EmailRegisterResetPayload(BaseModel): + token: str = Field(...) + new_password: str = Field(...) + password_confirm: str = Field(...) + + @field_validator("new_password", "password_confirm") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + @console_ns.route("/email-register/send-email") class EmailRegisterSendEmailApi(Resource): @@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource): @email_password_login_enabled @email_register_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = EmailRegisterSendPayload.model_validate(console_ns.payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() language = "en-US" - if args["language"] in languages: - language = args["language"] + if args.language in languages: + language = args.language - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): raise AccountInFreezeError() with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() token = None - token = AccountService.send_email_register_email(email=args["email"], account=account, language=language) + token = AccountService.send_email_register_email(email=args.email, account=account, language=language) return {"result": "success", "data": token} @@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource): @email_password_login_enabled @email_register_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = EmailRegisterValidityPayload.model_validate(console_ns.payload) - user_email = args["email"] + user_email = args.email - is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"]) + is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email) if is_email_register_error_rate_limit: raise EmailRegisterLimitError() - token_data = AccountService.get_email_register_data(args["token"]) + token_data = AccountService.get_email_register_data(args.token) if token_data is None: raise InvalidTokenError() if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): - AccountService.add_email_register_error_rate_limit(args["email"]) + if args.code != token_data.get("code"): + AccountService.add_email_register_error_rate_limit(args.email) raise EmailCodeError() # Verified, revoke the first token - AccountService.revoke_email_register_token(args["token"]) + AccountService.revoke_email_register_token(args.token) # Refresh token data by generating a new token _, new_token = AccountService.generate_email_register_token( - user_email, code=args["code"], additional_data={"phase": "register"} + user_email, code=args.code, additional_data={"phase": "register"} ) - AccountService.reset_email_register_error_rate_limit(args["email"]) + AccountService.reset_email_register_error_rate_limit(args.email) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} @@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource): @email_password_login_enabled @email_register_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("token", type=str, required=True, nullable=False, location="json") - .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") - .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = EmailRegisterResetPayload.model_validate(console_ns.payload) # Validate passwords match - if args["new_password"] != args["password_confirm"]: + if args.new_password != args.password_confirm: raise PasswordMismatchError() # Validate token and get register data - register_data = AccountService.get_email_register_data(args["token"]) + register_data = AccountService.get_email_register_data(args.token) if not register_data: raise InvalidTokenError() # Must use token in reset phase @@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource): raise InvalidTokenError() # Revoke token to prevent reuse - AccountService.revoke_email_register_token(args["token"]) + AccountService.revoke_email_register_token(args.token) email = register_data.get("email", "") @@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource): if account: raise EmailAlreadyInUseError() else: - account = self._create_new_account(email, args["password_confirm"]) + account = self._create_new_account(email, args.password_confirm) if not account: raise AccountNotFoundError() token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ee561bdd30..661f591182 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,7 +2,8 @@ import base64 import secrets from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db -from libs.helper import email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password from models import Account from services.account_service import AccountService, TenantService from services.feature_service import FeatureService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ForgotPasswordSendPayload(BaseModel): + email: EmailStr = Field(...) + language: str | None = Field(default=None) + + +class ForgotPasswordCheckPayload(BaseModel): + email: EmailStr = Field(...) + code: str = Field(...) + token: str = Field(...) + + +class ForgotPasswordResetPayload(BaseModel): + token: str = Field(...) + new_password: str = Field(...) + password_confirm: str = Field(...) + + @field_validator("new_password", "password_confirm") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + @console_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): @console_ns.doc("send_forgot_password_email") @console_ns.doc(description="Send password reset email") - @console_ns.expect( - console_ns.model( - "ForgotPasswordEmailRequest", - { - "email": fields.String(required=True, description="Email address"), - "language": fields.String(description="Language for email (zh-Hans/en-US)"), - }, - ) - ) + @console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__]) @console_ns.response( 200, "Email sent successfully", @@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource): @setup_required @email_password_login_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = ForgotPasswordSendPayload.model_validate(console_ns.payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": + if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() token = AccountService.send_reset_password_email( account=account, - email=args["email"], + email=args.email, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, ) @@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordCheckApi(Resource): @console_ns.doc("check_forgot_password_code") @console_ns.doc(description="Verify password reset code") - @console_ns.expect( - console_ns.model( - "ForgotPasswordCheckRequest", - { - "email": fields.String(required=True, description="Email address"), - "code": fields.String(required=True, description="Verification code"), - "token": fields.String(required=True, description="Reset token"), - }, - ) - ) + @console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__]) @console_ns.response( 200, "Code verified successfully", @@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource): @setup_required @email_password_login_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = ForgotPasswordCheckPayload.model_validate(console_ns.payload) - user_email = args["email"] + user_email = args.email - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() - token_data = AccountService.get_reset_password_data(args["token"]) + token_data = AccountService.get_reset_password_data(args.token) if token_data is None: raise InvalidTokenError() if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(args["email"]) + if args.code != token_data.get("code"): + AccountService.add_forgot_password_error_rate_limit(args.email) raise EmailCodeError() # Verified, revoke the first token - AccountService.revoke_reset_password_token(args["token"]) + AccountService.revoke_reset_password_token(args.token) # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=args["code"], additional_data={"phase": "reset"} + user_email, code=args.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args["email"]) + AccountService.reset_forgot_password_error_rate_limit(args.email) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} @@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource): class ForgotPasswordResetApi(Resource): @console_ns.doc("reset_password") @console_ns.doc(description="Reset password with verification token") - @console_ns.expect( - console_ns.model( - "ForgotPasswordResetRequest", - { - "token": fields.String(required=True, description="Verification token"), - "new_password": fields.String(required=True, description="New password"), - "password_confirm": fields.String(required=True, description="Password confirmation"), - }, - ) - ) + @console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__]) @console_ns.response( 200, "Password reset successfully", @@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource): @setup_required @email_password_login_enabled def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("token", type=str, required=True, nullable=False, location="json") - .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") - .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") - ) - args = parser.parse_args() + args = ForgotPasswordResetPayload.model_validate(console_ns.payload) # Validate passwords match - if args["new_password"] != args["password_confirm"]: + if args.new_password != args.password_confirm: raise PasswordMismatchError() # Validate token and get reset data - reset_data = AccountService.get_reset_password_data(args["token"]) + reset_data = AccountService.get_reset_password_data(args.token) if not reset_data: raise InvalidTokenError() # Must use token in reset phase @@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource): raise InvalidTokenError() # Revoke token to prevent reuse - AccountService.revoke_reset_password_token(args["token"]) + AccountService.revoke_reset_password_token(args.token) # Generate secure salt and hash password salt = secrets.token_bytes(16) - password_hashed = hash_password(args["new_password"], salt) + password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 77ecd5a5e4..f486f4c313 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,6 +1,7 @@ import flask_login from flask import make_response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field import services from configs import dify_config @@ -23,7 +24,7 @@ from controllers.console.error import ( ) from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created -from libs.helper import email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.login import current_account_with_tenant from libs.token import ( clear_access_token_from_cookie, @@ -40,6 +41,36 @@ from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class LoginPayload(BaseModel): + email: EmailStr = Field(..., description="Email address") + password: str = Field(..., description="Password") + remember_me: bool = Field(default=False, description="Remember me flag") + invite_token: str | None = Field(default=None, description="Invitation token") + + +class EmailPayload(BaseModel): + email: EmailStr = Field(...) + language: str | None = Field(default=None) + + +class EmailCodeLoginPayload(BaseModel): + email: EmailStr = Field(...) + code: str = Field(...) + token: str = Field(...) + language: str | None = Field(default=None) + + +def reg(cls: type[BaseModel]): + console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + + +reg(LoginPayload) +reg(EmailPayload) +reg(EmailCodeLoginPayload) + @console_ns.route("/login") class LoginApi(Resource): @@ -47,41 +78,36 @@ class LoginApi(Resource): @setup_required @email_password_login_enabled + @console_ns.expect(console_ns.models[LoginPayload.__name__]) def post(self): """Authenticate user and login.""" - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("password", type=str, required=True, location="json") - .add_argument("remember_me", type=bool, required=False, default=False, location="json") - .add_argument("invite_token", type=str, required=False, default=None, location="json") - ) - args = parser.parse_args() + args = LoginPayload.model_validate(console_ns.payload) - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): raise AccountInFreezeError() - is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) + is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email) if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() - invitation = args["invite_token"] + # TODO: why invitation is re-assigned with different type? + invitation = args.invite_token # type: ignore if invitation: - invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) + invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore try: if invitation: - data = invitation.get("data", {}) + data = invitation.get("data", {}) # type: ignore invitee_email = data.get("email") if data else None - if invitee_email != args["email"]: + if invitee_email != args.email: raise InvalidEmailError() - account = AccountService.authenticate(args["email"], args["password"], args["invite_token"]) + account = AccountService.authenticate(args.email, args.password, args.invite_token) else: - account = AccountService.authenticate(args["email"], args["password"]) + account = AccountService.authenticate(args.email, args.password) except services.errors.account.AccountLoginError: raise AccountBannedError() except services.errors.account.AccountPasswordError: - AccountService.add_login_error_rate_limit(args["email"]) + AccountService.add_login_error_rate_limit(args.email) raise AuthenticationFailedError() # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) @@ -97,7 +123,7 @@ class LoginApi(Resource): } token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(args["email"]) + AccountService.reset_login_error_rate_limit(args.email) # Create response with cookies instead of returning tokens in body response = make_response({"result": "success"}) @@ -134,25 +160,21 @@ class LogoutApi(Resource): class ResetPasswordSendEmailApi(Resource): @setup_required @email_password_login_enabled + @console_ns.expect(console_ns.models[EmailPayload.__name__]) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = EmailPayload.model_validate(console_ns.payload) - if args["language"] is not None and args["language"] == "zh-Hans": + if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" try: - account = AccountService.get_user_through_email(args["email"]) + account = AccountService.get_user_through_email(args.email) except AccountRegisterError: raise AccountInFreezeError() token = AccountService.send_reset_password_email( - email=args["email"], + email=args.email, account=account, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, @@ -164,30 +186,26 @@ class ResetPasswordSendEmailApi(Resource): @console_ns.route("/email-code-login") class EmailCodeLoginSendEmailApi(Resource): @setup_required + @console_ns.expect(console_ns.models[EmailPayload.__name__]) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = EmailPayload.model_validate(console_ns.payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": + if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" try: - account = AccountService.get_user_through_email(args["email"]) + account = AccountService.get_user_through_email(args.email) except AccountRegisterError: raise AccountInFreezeError() if account is None: if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_email_code_login_email(email=args["email"], language=language) + token = AccountService.send_email_code_login_email(email=args.email, language=language) else: raise AccountNotFound() else: @@ -199,30 +217,24 @@ class EmailCodeLoginSendEmailApi(Resource): @console_ns.route("/email-code-login/validity") class EmailCodeLoginApi(Resource): @setup_required + @console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__]) def post(self): - parser = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = EmailCodeLoginPayload.model_validate(console_ns.payload) - user_email = args["email"] - language = args["language"] + user_email = args.email + language = args.language - token_data = AccountService.get_email_code_login_data(args["token"]) + token_data = AccountService.get_email_code_login_data(args.token) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: + if token_data["email"] != args.email: raise InvalidEmailError() - if token_data["code"] != args["code"]: + if token_data["code"] != args.code: raise EmailCodeError() - AccountService.revoke_email_code_login_token(args["token"]) + AccountService.revoke_email_code_login_token(args.token) try: account = AccountService.get_user_through_email(user_email) except AccountRegisterError: @@ -255,7 +267,7 @@ class EmailCodeLoginApi(Resource): except WorkspacesLimitExceededError: raise WorkspacesLimitExceeded() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(args["email"]) + AccountService.reset_login_error_rate_limit(args.email) # Create response with cookies instead of returning tokens in body response = make_response({"result": "success"}) diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 5e12aa7d03..6162d88a0b 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -3,7 +3,8 @@ from functools import wraps from typing import Concatenate, ParamSpec, TypeVar from flask import jsonify, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required @@ -20,15 +21,34 @@ R = TypeVar("R") T = TypeVar("T") +class OAuthClientPayload(BaseModel): + client_id: str + + +class OAuthProviderRequest(BaseModel): + client_id: str + redirect_uri: str + + +class OAuthTokenRequest(BaseModel): + client_id: str + grant_type: str + code: str | None = None + client_secret: str | None = None + redirect_uri: str | None = None + refresh_token: str | None = None + + def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]): @wraps(view) def decorated(self: T, *args: P.args, **kwargs: P.kwargs): - parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json") - parsed_args = parser.parse_args() - client_id = parsed_args.get("client_id") - if not client_id: + json_data = request.get_json() + if json_data is None: raise BadRequest("client_id is required") + payload = OAuthClientPayload.model_validate(json_data) + client_id = payload.client_id + oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id) if not oauth_provider_app: raise NotFound("client_id is invalid") @@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource): @setup_required @oauth_server_client_id_required def post(self, oauth_provider_app: OAuthProviderApp): - parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json") - parsed_args = parser.parse_args() - redirect_uri = parsed_args.get("redirect_uri") + payload = OAuthProviderRequest.model_validate(request.get_json()) + redirect_uri = payload.redirect_uri # check if redirect_uri is valid if redirect_uri not in oauth_provider_app.redirect_uris: @@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource): @setup_required @oauth_server_client_id_required def post(self, oauth_provider_app: OAuthProviderApp): - parser = ( - reqparse.RequestParser() - .add_argument("grant_type", type=str, required=True, location="json") - .add_argument("code", type=str, required=False, location="json") - .add_argument("client_secret", type=str, required=False, location="json") - .add_argument("redirect_uri", type=str, required=False, location="json") - .add_argument("refresh_token", type=str, required=False, location="json") - ) - parsed_args = parser.parse_args() + payload = OAuthTokenRequest.model_validate(request.get_json()) try: - grant_type = OAuthGrantType(parsed_args["grant_type"]) + grant_type = OAuthGrantType(payload.grant_type) except ValueError: raise BadRequest("invalid grant_type") if grant_type == OAuthGrantType.AUTHORIZATION_CODE: - if not parsed_args["code"]: + if not payload.code: raise BadRequest("code is required") - if parsed_args["client_secret"] != oauth_provider_app.client_secret: + if payload.client_secret != oauth_provider_app.client_secret: raise BadRequest("client_secret is invalid") - if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris: + if payload.redirect_uri not in oauth_provider_app.redirect_uris: raise BadRequest("redirect_uri is invalid") access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id + grant_type, code=payload.code, client_id=oauth_provider_app.client_id ) return jsonable_encoder( { @@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource): } ) elif grant_type == OAuthGrantType.REFRESH_TOKEN: - if not parsed_args["refresh_token"]: + if not payload.refresh_token: raise BadRequest("refresh_token is required") access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id + grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id ) return jsonable_encoder( { diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4fef1ba40d..7f907dc420 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,6 +1,8 @@ import base64 -from flask_restx import Resource, fields, reqparse +from flask import request +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest from controllers.console import console_ns @@ -9,6 +11,35 @@ from enums.cloud_plan import CloudPlan from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class SubscriptionQuery(BaseModel): + plan: str = Field(..., description="Subscription plan") + interval: str = Field(..., description="Billing interval") + + @field_validator("plan") + @classmethod + def validate_plan(cls, value: str) -> str: + if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]: + raise ValueError("Invalid plan") + return value + + @field_validator("interval") + @classmethod + def validate_interval(cls, value: str) -> str: + if value not in {"month", "year"}: + raise ValueError("Invalid interval") + return value + + +class PartnerTenantsPayload(BaseModel): + click_id: str = Field(..., description="Click Id from partner referral link") + + +for model in (SubscriptionQuery, PartnerTenantsPayload): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + @console_ns.route("/billing/subscription") class Subscription(Resource): @@ -18,20 +49,9 @@ class Subscription(Resource): @only_edition_cloud def get(self): current_user, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument( - "plan", - type=str, - required=True, - location="args", - choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM], - ) - .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) - ) - args = parser.parse_args() + args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore BillingService.is_tenant_owner_or_admin(current_user) - return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id) + return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id) @console_ns.route("/billing/invoices") @@ -65,11 +85,10 @@ class PartnerTenants(Resource): @only_edition_cloud def put(self, partner_key: str): current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json") - args = parser.parse_args() try: - click_id = args["click_id"] + args = PartnerTenantsPayload.model_validate(console_ns.payload or {}) + click_id = args.click_id decoded_partner_key = base64.b64decode(partner_key).decode("utf-8") except Exception: raise BadRequest("Invalid partner_key") diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py index 2a6889968c..afc5f92b68 100644 --- a/api/controllers/console/billing/compliance.py +++ b/api/controllers/console/billing/compliance.py @@ -1,5 +1,6 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required @@ -9,16 +10,28 @@ from .. import console_ns from ..wraps import account_initialization_required, only_edition_cloud, setup_required +class ComplianceDownloadQuery(BaseModel): + doc_name: str = Field(..., description="Compliance document name") + + +console_ns.schema_model( + ComplianceDownloadQuery.__name__, + ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"), +) + + @console_ns.route("/compliance/download") class ComplianceApi(Resource): + @console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__]) + @console_ns.doc("download_compliance_document") + @console_ns.doc(description="Get compliance document download link") @setup_required @login_required @account_initialization_required @only_edition_cloud def get(self): current_user, current_tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args") - args = parser.parse_args() + args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore ip_address = extract_remote_ip(request) device_info = request.headers.get("User-Agent", "Unknown device") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 5a9c3ef133..2b2f807694 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,4 +1,6 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from flask import request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from constants.languages import languages from controllers.console import console_ns @@ -35,20 +37,26 @@ recommended_app_list_fields = { } -parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args") +class RecommendedAppsQuery(BaseModel): + language: str | None = Field(default=None) + + +console_ns.schema_model( + RecommendedAppsQuery.__name__, + RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"), +) @console_ns.route("/explore/apps") class RecommendedAppListApi(Resource): - @console_ns.expect(parser_apps) + @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__]) @login_required @account_initialization_required @marshal_with(recommended_app_list_fields) def get(self): # language args - args = parser_apps.parse_args() - - language = args.get("language") + args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + language = args.language if language and language in languages: language_prefix = language elif current_user and current_user.interface_language: diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index f27fa26983..2bebe79eac 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,13 +1,13 @@ import os from flask import session -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config from extensions.ext_database import db -from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService @@ -15,6 +15,18 @@ from . import console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class InitValidatePayload(BaseModel): + password: str = Field(..., max_length=30) + + +console_ns.schema_model( + InitValidatePayload.__name__, + InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route("/init") class InitValidateAPI(Resource): @@ -37,12 +49,7 @@ class InitValidateAPI(Resource): @console_ns.doc("validate_init_password") @console_ns.doc(description="Validate initialization password for self-hosted edition") - @console_ns.expect( - console_ns.model( - "InitValidateRequest", - {"password": fields.String(required=True, description="Initialization password", max_length=30)}, - ) - ) + @console_ns.expect(console_ns.models[InitValidatePayload.__name__]) @console_ns.response( 201, "Success", @@ -57,8 +64,8 @@ class InitValidateAPI(Resource): if tenant_count > 0: raise AlreadySetupError() - parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json") - input_password = parser.parse_args()["password"] + payload = InitValidatePayload.model_validate(console_ns.payload) + input_password = payload.password if input_password != os.environ.get("INIT_PASSWORD"): session["is_init_validated"] = False diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 49a4df1b5a..47eef7eb7e 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,7 +1,8 @@ import urllib.parse import httpx -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field import services from controllers.common import helpers @@ -36,17 +37,23 @@ class RemoteFileInfoApi(Resource): } -parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required") +class RemoteFileUploadPayload(BaseModel): + url: str = Field(..., description="URL to fetch") + + +console_ns.schema_model( + RemoteFileUploadPayload.__name__, + RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"), +) @console_ns.route("/remote-files/upload") class RemoteFileUploadApi(Resource): - @console_ns.expect(parser_upload) + @console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__]) @marshal_with(file_fields_with_signed_url) def post(self): - args = parser_upload.parse_args() - - url = args["url"] + args = RemoteFileUploadPayload.model_validate(console_ns.payload) + url = args.url try: resp = ssrf_proxy.head(url=url) diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 0c2a4d797b..7fa02ae280 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,8 +1,9 @@ from flask import request -from flask_restx import Resource, fields, reqparse +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from configs import dify_config -from libs.helper import StrLen, email, extract_remote_ip +from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models.model import DifySetup, db from services.account_service import RegisterService, TenantService @@ -12,6 +13,26 @@ from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class SetupRequestPayload(BaseModel): + email: EmailStr = Field(..., description="Admin email address") + name: str = Field(..., max_length=30, description="Admin name (max 30 characters)") + password: str = Field(..., description="Admin password") + language: str | None = Field(default=None, description="Admin language") + + @field_validator("password") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) + + +console_ns.schema_model( + SetupRequestPayload.__name__, + SetupRequestPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route("/setup") class SetupApi(Resource): @@ -42,17 +63,7 @@ class SetupApi(Resource): @console_ns.doc("setup_system") @console_ns.doc(description="Initialize system setup with admin account") - @console_ns.expect( - console_ns.model( - "SetupRequest", - { - "email": fields.String(required=True, description="Admin email address"), - "name": fields.String(required=True, description="Admin name (max 30 characters)"), - "password": fields.String(required=True, description="Admin password"), - "language": fields.String(required=False, description="Admin language"), - }, - ) - ) + @console_ns.expect(console_ns.models[SetupRequestPayload.__name__]) @console_ns.response( 201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")}) ) @@ -72,22 +83,15 @@ class SetupApi(Resource): if not get_init_validate_status(): raise NotInitValidateError() - parser = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("name", type=StrLen(30), required=True, location="json") - .add_argument("password", type=valid_password, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - ) - args = parser.parse_args() + args = SetupRequestPayload.model_validate(console_ns.payload) # setup RegisterService.setup( - email=args["email"], - name=args["name"], - password=args["password"], + email=args.email, + name=args.name, + password=args.password, ip_address=extract_remote_ip(request), - language=args["language"], + language=args.language, ) return {"result": "success"}, 201 diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 4e3d9d6786..419261ba2a 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,8 +2,10 @@ import json import logging import httpx -from flask_restx import Resource, fields, reqparse +from flask import request +from flask_restx import Resource, fields from packaging import version +from pydantic import BaseModel, Field from configs import dify_config @@ -11,8 +13,14 @@ from . import console_ns logger = logging.getLogger(__name__) -parser = reqparse.RequestParser().add_argument( - "current_version", type=str, required=True, location="args", help="Current application version" + +class VersionQuery(BaseModel): + current_version: str = Field(..., description="Current application version") + + +console_ns.schema_model( + VersionQuery.__name__, + VersionQuery.model_json_schema(ref_template="#/definitions/{model}"), ) @@ -20,7 +28,7 @@ parser = reqparse.RequestParser().add_argument( class VersionApi(Resource): @console_ns.doc("check_version_update") @console_ns.doc(description="Check for application version updates") - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[VersionQuery.__name__]) @console_ns.response( 200, "Success", @@ -37,7 +45,7 @@ class VersionApi(Resource): ) def get(self): """Check for application version updates""" - args = parser.parse_args() + args = VersionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore check_update_url = dify_config.CHECK_UPDATE_URL result = { @@ -57,16 +65,16 @@ class VersionApi(Resource): try: response = httpx.get( check_update_url, - params={"current_version": args["current_version"]}, + params={"current_version": args.current_version}, timeout=httpx.Timeout(timeout=10.0, connect=3.0), ) except Exception as error: logger.warning("Check update version error: %s.", str(error)) - result["version"] = args["current_version"] + result["version"] = args.current_version return result content = json.loads(response.content) - if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): + if _has_new_version(latest_version=content["version"], current_version=f"{args.current_version}"): result["version"] = content["version"] result["release_date"] = content["releaseDate"] result["release_notes"] = content["releaseNotes"] diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 6334314988..55eaa2f09f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -37,7 +37,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from fields.member_fields import account_fields from libs.datetime_utils import naive_utc_now -from libs.helper import TimestampField, email, extract_remote_ip, timezone +from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import Account, AccountIntegrate, InvitationCode from services.account_service import AccountService @@ -111,14 +111,9 @@ class AccountDeletePayload(BaseModel): class AccountDeletionFeedbackPayload(BaseModel): - email: str + email: EmailStr feedback: str - @field_validator("email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) - class EducationActivatePayload(BaseModel): token: str @@ -133,45 +128,25 @@ class EducationAutocompleteQuery(BaseModel): class ChangeEmailSendPayload(BaseModel): - email: str + email: EmailStr language: str | None = None phase: str | None = None token: str | None = None - @field_validator("email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) - class ChangeEmailValidityPayload(BaseModel): - email: str + email: EmailStr code: str token: str - @field_validator("email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) - class ChangeEmailResetPayload(BaseModel): - new_email: str + new_email: EmailStr token: str - @field_validator("new_email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) - class CheckEmailUniquePayload(BaseModel): - email: str - - @field_validator("email") - @classmethod - def validate_email(cls, value: str) -> str: - return email(value) + email: EmailStr def reg(cls: type[BaseModel]): diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index d320855f29..64f47f426a 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,7 +1,8 @@ from urllib.parse import quote from flask import Response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import NotFound import services @@ -11,6 +12,26 @@ from extensions.ext_database import db from services.account_service import TenantService from services.file_service import FileService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class FileSignatureQuery(BaseModel): + timestamp: str = Field(..., description="Unix timestamp used in the signature") + nonce: str = Field(..., description="Random string for signature") + sign: str = Field(..., description="HMAC signature") + + +class FilePreviewQuery(FileSignatureQuery): + as_attachment: bool = Field(default=False, description="Whether to download as attachment") + + +files_ns.schema_model( + FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +files_ns.schema_model( + FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + @files_ns.route("//image-preview") class ImagePreviewApi(Resource): @@ -36,12 +57,10 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - timestamp = request.args.get("timestamp") - nonce = request.args.get("nonce") - sign = request.args.get("sign") - - if not timestamp or not nonce or not sign: - return {"content": "Invalid request."}, 400 + args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + timestamp = args.timestamp + nonce = args.nonce + sign = args.sign try: generator, mimetype = FileService(db.engine).get_image_preview( @@ -80,25 +99,14 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - parser = ( - reqparse.RequestParser() - .add_argument("timestamp", type=str, required=True, location="args") - .add_argument("nonce", type=str, required=True, location="args") - .add_argument("sign", type=str, required=True, location="args") - .add_argument("as_attachment", type=bool, required=False, default=False, location="args") - ) - - args = parser.parse_args() - - if not args["timestamp"] or not args["nonce"] or not args["sign"]: - return {"content": "Invalid request."}, 400 + args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( file_id=file_id, - timestamp=args["timestamp"], - nonce=args["nonce"], - sign=args["sign"], + timestamp=args.timestamp, + nonce=args.nonce, + sign=args.sign, ) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() @@ -125,7 +133,7 @@ class FilePreviewApi(Resource): response.headers["Accept-Ranges"] = "bytes" if upload_file.size > 0: response.headers["Content-Length"] = str(upload_file.size) - if args["as_attachment"]: + if args.as_attachment: encoded_filename = quote(upload_file.name) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" response.headers["Content-Type"] = "application/octet-stream" diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index ecaeb85821..c487a0a915 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,7 +1,8 @@ from urllib.parse import quote -from flask import Response -from flask_restx import Resource, reqparse +from flask import Response, request +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound from controllers.common.errors import UnsupportedFileTypeError @@ -10,6 +11,20 @@ from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db as global_db +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ToolFileQuery(BaseModel): + timestamp: str = Field(..., description="Unix timestamp") + nonce: str = Field(..., description="Random nonce") + sign: str = Field(..., description="HMAC signature") + as_attachment: bool = Field(default=False, description="Download as attachment") + + +files_ns.schema_model( + ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + @files_ns.route("/tools/.") class ToolFileApi(Resource): @@ -36,18 +51,8 @@ class ToolFileApi(Resource): def get(self, file_id, extension): file_id = str(file_id) - parser = ( - reqparse.RequestParser() - .add_argument("timestamp", type=str, required=True, location="args") - .add_argument("nonce", type=str, required=True, location="args") - .add_argument("sign", type=str, required=True, location="args") - .add_argument("as_attachment", type=bool, required=False, default=False, location="args") - ) - - args = parser.parse_args() - if not verify_tool_file_signature( - file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"] - ): + args = ToolFileQuery.model_validate(request.args.to_dict()) + if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign): raise Forbidden("Invalid request.") try: @@ -69,7 +74,7 @@ class ToolFileApi(Resource): ) if tool_file.size > 0: response.headers["Content-Length"] = str(tool_file.size) - if args["as_attachment"]: + if args.as_attachment: encoded_filename = quote(tool_file.name) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index a09e24e2d9..6096a87c56 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -1,40 +1,45 @@ from mimetypes import guess_extension -from flask_restx import Resource, reqparse +from flask import request +from flask_restx import Resource from flask_restx.api import HTTPStatus +from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden import services -from controllers.common.errors import ( - FileTooLargeError, - UnsupportedFileTypeError, -) -from controllers.console.wraps import setup_required -from controllers.files import files_ns -from controllers.inner_api.plugin.wraps import get_user from core.file.helpers import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager from fields.file_fields import build_file_model -# Define parser for both documentation and validation -upload_parser = ( - reqparse.RequestParser() - .add_argument("file", location="files", type=FileStorage, required=True, help="File to upload") - .add_argument( - "timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification" - ) - .add_argument("nonce", type=str, required=True, location="args", help="Random string for signature verification") - .add_argument("sign", type=str, required=True, location="args", help="HMAC signature for request validation") - .add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier") - .add_argument("user_id", type=str, required=False, location="args", help="User identifier") +from ..common.errors import ( + FileTooLargeError, + UnsupportedFileTypeError, +) +from ..console.wraps import setup_required +from ..files import files_ns +from ..inner_api.plugin.wraps import get_user + +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class PluginUploadQuery(BaseModel): + timestamp: str = Field(..., description="Unix timestamp for signature verification") + nonce: str = Field(..., description="Random nonce for signature verification") + sign: str = Field(..., description="HMAC signature") + tenant_id: str = Field(..., description="Tenant identifier") + user_id: str | None = Field(default=None, description="User identifier") + + +files_ns.schema_model( + PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @files_ns.route("/upload/for-plugin") class PluginUploadFileApi(Resource): @setup_required - @files_ns.expect(upload_parser) + @files_ns.expect(files_ns.models[PluginUploadQuery.__name__]) @files_ns.doc("upload_plugin_file") @files_ns.doc(description="Upload a file for plugin usage with signature verification") @files_ns.doc( @@ -62,15 +67,17 @@ class PluginUploadFileApi(Resource): FileTooLargeError: File exceeds size limit UnsupportedFileTypeError: File type not supported """ - # Parse and validate all arguments - args = upload_parser.parse_args() + args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - file: FileStorage = args["file"] - timestamp: str = args["timestamp"] - nonce: str = args["nonce"] - sign: str = args["sign"] - tenant_id: str = args["tenant_id"] - user_id: str | None = args.get("user_id") + file: FileStorage | None = request.files.get("file") + if file is None: + raise Forbidden("File is required.") + + timestamp = args.timestamp + nonce = args.nonce + sign = args.sign + tenant_id = args.tenant_id + user_id = args.user_id user = get_user(tenant_id, user_id) filename: str | None = file.filename diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index e1c96fb050..84266ab0fa 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -256,7 +256,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation] now = datetime_utils.naive_utc_now() last_update = _get_last_update_timestamp(cache_key) - if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: + if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS: # type: ignore update_values["last_used"] = values.last_used _set_last_update_timestamp(cache_key, now) diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 588fbae285..5e75bc36b0 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union import redis from redis import RedisError @@ -245,7 +245,12 @@ def init_app(app: DifyApp): app.extensions["redis"] = redis_client -def redis_fallback(default_return: Any | None = None): +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T") + + +def redis_fallback(default_return: T | None = None): # type: ignore """ decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. @@ -253,9 +258,9 @@ def redis_fallback(default_return: Any | None = None): default_return: The value to return when a Redis operation fails. Defaults to None. """ - def decorator(func: Callable): + def decorator(func: Callable[P, R]): @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs): try: return func(*args, **kwargs) except RedisError as e: diff --git a/api/libs/helper.py b/api/libs/helper.py index 1013c3b878..0506e0ed5f 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -10,12 +10,13 @@ import uuid from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields from pydantic import BaseModel +from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator @@ -103,6 +104,9 @@ def email(email): raise ValueError(error) +EmailStr = Annotated[str, AfterValidator(email)] + + def uuid_value(value): if value == "": return str(value) diff --git a/api/pyrefly.toml b/api/pyrefly.toml new file mode 100644 index 0000000000..80ffba019d --- /dev/null +++ b/api/pyrefly.toml @@ -0,0 +1,10 @@ +project-includes = ["."] +project-excludes = [ + "tests/", + ".venv", + "migrations/", + "core/rag", +] +python-platform = "linux" +python-version = "3.11.0" +infer-with-first-use = false diff --git a/api/services/account_service.py b/api/services/account_service.py index ac6d1bde77..5a549dc318 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1259,7 +1259,7 @@ class RegisterService: return f"member_invite:token:{token}" @classmethod - def setup(cls, email: str, name: str, password: str, ip_address: str, language: str): + def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None): """ Setup dify @@ -1267,6 +1267,7 @@ class RegisterService: :param name: username :param password: password :param ip_address: ip address + :param language: language """ try: account = AccountService.create_account( @@ -1414,7 +1415,7 @@ class RegisterService: return data is not None @classmethod - def revoke_token(cls, workspace_id: str, email: str, token: str): + def revoke_token(cls, workspace_id: str | None, email: str | None, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" @@ -1423,7 +1424,9 @@ class RegisterService: redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None: + def get_invitation_if_token_valid( + cls, workspace_id: str | None, email: str | None, token: str + ) -> dict[str, Any] | None: invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None