diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index 597ff1f6c5..b3e5c9619f 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -1,6 +1,8 @@ import logging -from flask_restx import Resource, marshal_with, reqparse +from flask import request +from flask_restx import Resource, marshal_with +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -18,16 +20,30 @@ from ..app.wraps import get_app_model from ..wraps import account_initialization_required, edit_permission_required, setup_required logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required") +class Parser(BaseModel): + node_id: str + + +class ParserEnable(BaseModel): + trigger_id: str + enable_trigger: bool + + +console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + +console_ns.schema_model( + ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) @console_ns.route("/apps//workflows/triggers/webhook") class WebhookTriggerApi(Resource): """Webhook Trigger API""" - @console_ns.expect(parser) + @console_ns.expect(console_ns.models[Parser.__name__], validate=True) @setup_required @login_required @account_initialization_required @@ -35,9 +51,9 @@ class WebhookTriggerApi(Resource): @marshal_with(webhook_trigger_fields) def get(self, app_model: App): """Get webhook trigger for a node""" - args = parser.parse_args() + args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore - node_id = str(args["node_id"]) + node_id = args.node_id with Session(db.engine) as session: # Get webhook trigger for this app and node @@ -96,16 +112,9 @@ class AppTriggersApi(Resource): return {"data": triggers} -parser_enable = ( - reqparse.RequestParser() - .add_argument("trigger_id", type=str, required=True, nullable=False, location="json") - .add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json") -) - - @console_ns.route("/apps//trigger-enable") class AppTriggerEnableApi(Resource): - @console_ns.expect(parser_enable) + @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True) @setup_required @login_required @account_initialization_required @@ -114,12 +123,11 @@ class AppTriggerEnableApi(Resource): @marshal_with(trigger_fields) def post(self, app_model: App): """Update app trigger (enable/disable)""" - args = parser_enable.parse_args() + args = ParserEnable.model_validate(console_ns.payload) assert current_user.current_tenant_id is not None - trigger_id = args["trigger_id"] - + trigger_id = args.trigger_id with Session(db.engine) as session: # Find the trigger using select trigger = session.execute( @@ -134,7 +142,7 @@ class AppTriggerEnableApi(Resource): raise NotFound("Trigger not found") # Update status based on enable_trigger boolean - trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED + trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED session.commit() session.refresh(trigger) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 838cd3ee95..b4d1b42657 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,8 +1,10 @@ from datetime import datetime +from typing import Literal import pytz from flask import request -from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -42,20 +44,198 @@ from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -def _init_parser(): - parser = reqparse.RequestParser() - if dify_config.EDITION == "CLOUD": - parser.add_argument("invitation_code", type=str, location="json") - parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument( - "timezone", type=timezone, required=True, location="json" - ) - return parser + +class AccountInitPayload(BaseModel): + interface_language: str + timezone: str + invitation_code: str | None = None + + @field_validator("interface_language") + @classmethod + def validate_language(cls, value: str) -> str: + return supported_language(value) + + @field_validator("timezone") + @classmethod + def validate_timezone(cls, value: str) -> str: + return timezone(value) + + +class AccountNamePayload(BaseModel): + name: str = Field(min_length=3, max_length=30) + + +class AccountAvatarPayload(BaseModel): + avatar: str + + +class AccountInterfaceLanguagePayload(BaseModel): + interface_language: str + + @field_validator("interface_language") + @classmethod + def validate_language(cls, value: str) -> str: + return supported_language(value) + + +class AccountInterfaceThemePayload(BaseModel): + interface_theme: Literal["light", "dark"] + + +class AccountTimezonePayload(BaseModel): + timezone: str + + @field_validator("timezone") + @classmethod + def validate_timezone(cls, value: str) -> str: + return timezone(value) + + +class AccountPasswordPayload(BaseModel): + password: str | None = None + new_password: str + repeat_new_password: str + + @model_validator(mode="after") + def check_passwords_match(self) -> "AccountPasswordPayload": + if self.new_password != self.repeat_new_password: + raise RepeatPasswordNotMatchError() + return self + + +class AccountDeletePayload(BaseModel): + token: str + code: str + + +class AccountDeletionFeedbackPayload(BaseModel): + email: str + feedback: str + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + return email(value) + + +class EducationActivatePayload(BaseModel): + token: str + institution: str + role: str + + +class EducationAutocompleteQuery(BaseModel): + keywords: str + page: int = 0 + limit: int = 20 + + +class ChangeEmailSendPayload(BaseModel): + email: str + language: str | None = None + phase: str | None = None + token: str | None = None + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + return email(value) + + +class ChangeEmailValidityPayload(BaseModel): + email: str + code: str + token: str + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + return email(value) + + +class ChangeEmailResetPayload(BaseModel): + new_email: str + token: str + + @field_validator("new_email") + @classmethod + def validate_email(cls, value: str) -> str: + return email(value) + + +class CheckEmailUniquePayload(BaseModel): + email: str + + @field_validator("email") + @classmethod + def validate_email(cls, value: str) -> str: + return email(value) + + +console_ns.schema_model( + AccountInitPayload.__name__, AccountInitPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + AccountNamePayload.__name__, AccountNamePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + AccountAvatarPayload.__name__, AccountAvatarPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) +console_ns.schema_model( + AccountInterfaceLanguagePayload.__name__, + AccountInterfaceLanguagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + AccountInterfaceThemePayload.__name__, + AccountInterfaceThemePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + AccountTimezonePayload.__name__, + AccountTimezonePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + AccountPasswordPayload.__name__, + AccountPasswordPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + AccountDeletePayload.__name__, + AccountDeletePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + AccountDeletionFeedbackPayload.__name__, + AccountDeletionFeedbackPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + EducationActivatePayload.__name__, + EducationActivatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + EducationAutocompleteQuery.__name__, + EducationAutocompleteQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + ChangeEmailSendPayload.__name__, + ChangeEmailSendPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + ChangeEmailValidityPayload.__name__, + ChangeEmailValidityPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + ChangeEmailResetPayload.__name__, + ChangeEmailResetPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + CheckEmailUniquePayload.__name__, + CheckEmailUniquePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) @console_ns.route("/account/init") class AccountInitApi(Resource): - @console_ns.expect(_init_parser()) + @console_ns.expect(console_ns.models[AccountInitPayload.__name__]) @setup_required @login_required def post(self): @@ -64,17 +244,18 @@ class AccountInitApi(Resource): if account.status == "active": raise AccountAlreadyInitedError() - args = _init_parser().parse_args() + payload = console_ns.payload or {} + args = AccountInitPayload.model_validate(payload) if dify_config.EDITION == "CLOUD": - if not args["invitation_code"]: + if not args.invitation_code: raise ValueError("invitation_code is required") # check invitation code invitation_code = ( db.session.query(InvitationCode) .where( - InvitationCode.code == args["invitation_code"], + InvitationCode.code == args.invitation_code, InvitationCode.status == "unused", ) .first() @@ -88,8 +269,8 @@ class AccountInitApi(Resource): invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id - account.interface_language = args["interface_language"] - account.timezone = args["timezone"] + account.interface_language = args.interface_language + account.timezone = args.timezone account.interface_theme = "light" account.status = "active" account.initialized_at = naive_utc_now() @@ -110,137 +291,104 @@ class AccountProfileApi(Resource): return current_user -parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json") - - @console_ns.route("/account/name") class AccountNameApi(Resource): - @console_ns.expect(parser_name) + @console_ns.expect(console_ns.models[AccountNamePayload.__name__]) @setup_required @login_required @account_initialization_required @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - args = parser_name.parse_args() - - # Validate account name length - if len(args["name"]) < 3 or len(args["name"]) > 30: - raise ValueError("Account name must be between 3 and 30 characters.") - - updated_account = AccountService.update_account(current_user, name=args["name"]) + payload = console_ns.payload or {} + args = AccountNamePayload.model_validate(payload) + updated_account = AccountService.update_account(current_user, name=args.name) return updated_account -parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json") - - @console_ns.route("/account/avatar") class AccountAvatarApi(Resource): - @console_ns.expect(parser_avatar) + @console_ns.expect(console_ns.models[AccountAvatarPayload.__name__]) @setup_required @login_required @account_initialization_required @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - args = parser_avatar.parse_args() + payload = console_ns.payload or {} + args = AccountAvatarPayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) + updated_account = AccountService.update_account(current_user, avatar=args.avatar) return updated_account -parser_interface = reqparse.RequestParser().add_argument( - "interface_language", type=supported_language, required=True, location="json" -) - - @console_ns.route("/account/interface-language") class AccountInterfaceLanguageApi(Resource): - @console_ns.expect(parser_interface) + @console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__]) @setup_required @login_required @account_initialization_required @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - args = parser_interface.parse_args() + payload = console_ns.payload or {} + args = AccountInterfaceLanguagePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) + updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) return updated_account -parser_theme = reqparse.RequestParser().add_argument( - "interface_theme", type=str, choices=["light", "dark"], required=True, location="json" -) - - @console_ns.route("/account/interface-theme") class AccountInterfaceThemeApi(Resource): - @console_ns.expect(parser_theme) + @console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__]) @setup_required @login_required @account_initialization_required @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - args = parser_theme.parse_args() + payload = console_ns.payload or {} + args = AccountInterfaceThemePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) + updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) return updated_account -parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json") - - @console_ns.route("/account/timezone") class AccountTimezoneApi(Resource): - @console_ns.expect(parser_timezone) + @console_ns.expect(console_ns.models[AccountTimezonePayload.__name__]) @setup_required @login_required @account_initialization_required @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - args = parser_timezone.parse_args() + payload = console_ns.payload or {} + args = AccountTimezonePayload.model_validate(payload) - # Validate timezone string, e.g. America/New_York, Asia/Shanghai - if args["timezone"] not in pytz.all_timezones: - raise ValueError("Invalid timezone string.") - - updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) + updated_account = AccountService.update_account(current_user, timezone=args.timezone) return updated_account -parser_pw = ( - reqparse.RequestParser() - .add_argument("password", type=str, required=False, location="json") - .add_argument("new_password", type=str, required=True, location="json") - .add_argument("repeat_new_password", type=str, required=True, location="json") -) - - @console_ns.route("/account/password") class AccountPasswordApi(Resource): - @console_ns.expect(parser_pw) + @console_ns.expect(console_ns.models[AccountPasswordPayload.__name__]) @setup_required @login_required @account_initialization_required @marshal_with(account_fields) def post(self): current_user, _ = current_account_with_tenant() - args = parser_pw.parse_args() - - if args["new_password"] != args["repeat_new_password"]: - raise RepeatPasswordNotMatchError() + payload = console_ns.payload or {} + args = AccountPasswordPayload.model_validate(payload) try: - AccountService.update_account_password(current_user, args["password"], args["new_password"]) + AccountService.update_account_password(current_user, args.password, args.new_password) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() @@ -316,25 +464,19 @@ class AccountDeleteVerifyApi(Resource): return {"result": "success", "data": token} -parser_delete = ( - reqparse.RequestParser() - .add_argument("token", type=str, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") -) - - @console_ns.route("/account/delete") class AccountDeleteApi(Resource): - @console_ns.expect(parser_delete) + @console_ns.expect(console_ns.models[AccountDeletePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): account, _ = current_account_with_tenant() - args = parser_delete.parse_args() + payload = console_ns.payload or {} + args = AccountDeletePayload.model_validate(payload) - if not AccountService.verify_account_deletion_code(args["token"], args["code"]): + if not AccountService.verify_account_deletion_code(args.token, args.code): raise InvalidAccountDeletionCodeError() AccountService.delete_account(account) @@ -342,21 +484,15 @@ class AccountDeleteApi(Resource): return {"result": "success"} -parser_feedback = ( - reqparse.RequestParser() - .add_argument("email", type=str, required=True, location="json") - .add_argument("feedback", type=str, required=True, location="json") -) - - @console_ns.route("/account/delete/feedback") class AccountDeleteUpdateFeedbackApi(Resource): - @console_ns.expect(parser_feedback) + @console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__]) @setup_required def post(self): - args = parser_feedback.parse_args() + payload = console_ns.payload or {} + args = AccountDeletionFeedbackPayload.model_validate(payload) - BillingService.update_account_deletion_feedback(args["email"], args["feedback"]) + BillingService.update_account_deletion_feedback(args.email, args.feedback) return {"result": "success"} @@ -379,14 +515,6 @@ class EducationVerifyApi(Resource): return BillingService.EducationIdentity.verify(account.id, account.email) -parser_edu = ( - reqparse.RequestParser() - .add_argument("token", type=str, required=True, location="json") - .add_argument("institution", type=str, required=True, location="json") - .add_argument("role", type=str, required=True, location="json") -) - - @console_ns.route("/account/education") class EducationApi(Resource): status_fields = { @@ -396,7 +524,7 @@ class EducationApi(Resource): "allow_refresh": fields.Boolean, } - @console_ns.expect(parser_edu) + @console_ns.expect(console_ns.models[EducationActivatePayload.__name__]) @setup_required @login_required @account_initialization_required @@ -405,9 +533,10 @@ class EducationApi(Resource): def post(self): account, _ = current_account_with_tenant() - args = parser_edu.parse_args() + payload = console_ns.payload or {} + args = EducationActivatePayload.model_validate(payload) - return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"]) + return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role) @setup_required @login_required @@ -425,14 +554,6 @@ class EducationApi(Resource): return res -parser_autocomplete = ( - reqparse.RequestParser() - .add_argument("keywords", type=str, required=True, location="args") - .add_argument("page", type=int, required=False, location="args", default=0) - .add_argument("limit", type=int, required=False, location="args", default=20) -) - - @console_ns.route("/account/education/autocomplete") class EducationAutoCompleteApi(Resource): data_fields = { @@ -441,7 +562,7 @@ class EducationAutoCompleteApi(Resource): "has_next": fields.Boolean, } - @console_ns.expect(parser_autocomplete) + @console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__]) @setup_required @login_required @account_initialization_required @@ -449,46 +570,39 @@ class EducationAutoCompleteApi(Resource): @cloud_edition_billing_enabled @marshal_with(data_fields) def get(self): - args = parser_autocomplete.parse_args() + payload = request.args.to_dict(flat=True) # type: ignore + args = EducationAutocompleteQuery.model_validate(payload) - return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) - - -parser_change_email = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("language", type=str, required=False, location="json") - .add_argument("phase", type=str, required=False, location="json") - .add_argument("token", type=str, required=False, location="json") -) + return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) @console_ns.route("/account/change-email") class ChangeEmailSendEmailApi(Resource): - @console_ns.expect(parser_change_email) + @console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__]) @enable_change_email @setup_required @login_required @account_initialization_required def post(self): current_user, _ = current_account_with_tenant() - args = parser_change_email.parse_args() + payload = console_ns.payload or {} + args = ChangeEmailSendPayload.model_validate(payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - if args["language"] is not None and args["language"] == "zh-Hans": + if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" account = None - user_email = args["email"] - if args["phase"] is not None and args["phase"] == "new_email": - if args["token"] is None: + user_email = args.email + if args.phase is not None and args.phase == "new_email": + if args.token is None: raise InvalidTokenError() - reset_data = AccountService.get_change_email_data(args["token"]) + reset_data = AccountService.get_change_email_data(args.token) if reset_data is None: raise InvalidTokenError() user_email = reset_data.get("email", "") @@ -497,118 +611,103 @@ class ChangeEmailSendEmailApi(Resource): raise InvalidEmailError() else: with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() if account is None: raise AccountNotFound() token = AccountService.send_change_email_email( - account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"] + account=account, email=args.email, old_email=user_email, language=language, phase=args.phase ) return {"result": "success", "data": token} -parser_validity = ( - reqparse.RequestParser() - .add_argument("email", type=email, required=True, location="json") - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/account/change-email/validity") class ChangeEmailCheckApi(Resource): - @console_ns.expect(parser_validity) + @console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__]) @enable_change_email @setup_required @login_required @account_initialization_required def post(self): - args = parser_validity.parse_args() + payload = console_ns.payload or {} + args = ChangeEmailValidityPayload.model_validate(payload) - user_email = args["email"] + user_email = args.email - is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"]) + is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email) if is_change_email_error_rate_limit: raise EmailChangeLimitError() - token_data = AccountService.get_change_email_data(args["token"]) + token_data = AccountService.get_change_email_data(args.token) if token_data is None: raise InvalidTokenError() if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): - AccountService.add_change_email_error_rate_limit(args["email"]) + if args.code != token_data.get("code"): + AccountService.add_change_email_error_rate_limit(args.email) raise EmailCodeError() # Verified, revoke the first token - AccountService.revoke_change_email_token(args["token"]) + AccountService.revoke_change_email_token(args.token) # Refresh token data by generating a new token _, new_token = AccountService.generate_change_email_token( - user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={} + user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} ) - AccountService.reset_change_email_error_rate_limit(args["email"]) + AccountService.reset_change_email_error_rate_limit(args.email) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} -parser_reset = ( - reqparse.RequestParser() - .add_argument("new_email", type=email, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/account/change-email/reset") class ChangeEmailResetApi(Resource): - @console_ns.expect(parser_reset) + @console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__]) @enable_change_email @setup_required @login_required @account_initialization_required @marshal_with(account_fields) def post(self): - args = parser_reset.parse_args() + payload = console_ns.payload or {} + args = ChangeEmailResetPayload.model_validate(payload) - if AccountService.is_account_in_freeze(args["new_email"]): + if AccountService.is_account_in_freeze(args.new_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(args["new_email"]): + if not AccountService.check_email_unique(args.new_email): raise EmailAlreadyInUseError() - reset_data = AccountService.get_change_email_data(args["token"]) + reset_data = AccountService.get_change_email_data(args.token) if not reset_data: raise InvalidTokenError() - AccountService.revoke_change_email_token(args["token"]) + AccountService.revoke_change_email_token(args.token) old_email = reset_data.get("old_email", "") current_user, _ = current_account_with_tenant() if current_user.email != old_email: raise AccountNotFound() - updated_account = AccountService.update_account_email(current_user, email=args["new_email"]) + updated_account = AccountService.update_account_email(current_user, email=args.new_email) AccountService.send_change_email_completed_notify_email( - email=args["new_email"], + email=args.new_email, ) return updated_account -parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json") - - @console_ns.route("/account/change-email/check-email-unique") class CheckEmailUnique(Resource): - @console_ns.expect(parser_check) + @console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__]) @setup_required def post(self): - args = parser_check.parse_args() - if AccountService.is_account_in_freeze(args["email"]): + payload = console_ns.payload or {} + args = CheckEmailUniquePayload.model_validate(payload) + if AccountService.is_account_in_freeze(args.email): raise AccountInFreezeError() - if not AccountService.check_email_unique(args["email"]): + if not AccountService.check_email_unique(args.email): raise EmailAlreadyInUseError() return {"result": "success"} diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index f17f8e4bcf..f72d247398 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,7 +1,8 @@ from urllib import parse from flask import abort, request -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field import services from configs import dify_config @@ -31,6 +32,53 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class MemberInvitePayload(BaseModel): + emails: list[str] = Field(default_factory=list) + role: TenantAccountRole + language: str | None = None + + +class MemberRoleUpdatePayload(BaseModel): + role: str + + +class OwnerTransferEmailPayload(BaseModel): + language: str | None = None + + +class OwnerTransferCheckPayload(BaseModel): + code: str + token: str + + +class OwnerTransferPayload(BaseModel): + token: str + + +console_ns.schema_model( + MemberInvitePayload.__name__, + MemberInvitePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + MemberRoleUpdatePayload.__name__, + MemberRoleUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + OwnerTransferEmailPayload.__name__, + OwnerTransferEmailPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + OwnerTransferCheckPayload.__name__, + OwnerTransferCheckPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + OwnerTransferPayload.__name__, + OwnerTransferPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + @console_ns.route("/workspaces/current/members") class MemberListApi(Resource): @@ -48,29 +96,22 @@ class MemberListApi(Resource): return {"result": "success", "accounts": members}, 200 -parser_invite = ( - reqparse.RequestParser() - .add_argument("emails", type=list, required=True, location="json") - .add_argument("role", type=str, required=True, default="admin", location="json") - .add_argument("language", type=str, required=False, location="json") -) - - @console_ns.route("/workspaces/current/members/invite-email") class MemberInviteEmailApi(Resource): """Invite a new member by email.""" - @console_ns.expect(parser_invite) + @console_ns.expect(console_ns.models[MemberInvitePayload.__name__]) @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("members") def post(self): - args = parser_invite.parse_args() + payload = console_ns.payload or {} + args = MemberInvitePayload.model_validate(payload) - invitee_emails = args["emails"] - invitee_role = args["role"] - interface_language = args["language"] + invitee_emails = args.emails + invitee_role = args.role + interface_language = args.language if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 current_user, _ = current_account_with_tenant() @@ -146,20 +187,18 @@ class MemberCancelInviteApi(Resource): }, 200 -parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json") - - @console_ns.route("/workspaces/current/members//update-role") class MemberUpdateRoleApi(Resource): """Update member role.""" - @console_ns.expect(parser_update) + @console_ns.expect(console_ns.models[MemberRoleUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required def put(self, member_id): - args = parser_update.parse_args() - new_role = args["role"] + payload = console_ns.payload or {} + args = MemberRoleUpdatePayload.model_validate(payload) + new_role = args.role if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 @@ -197,20 +236,18 @@ class DatasetOperatorMemberListApi(Resource): return {"result": "success", "accounts": members}, 200 -parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json") - - @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") class SendOwnerTransferEmailApi(Resource): """Send owner transfer email.""" - @console_ns.expect(parser_send) + @console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__]) @setup_required @login_required @account_initialization_required @is_allow_transfer_owner def post(self): - args = parser_send.parse_args() + payload = console_ns.payload or {} + args = OwnerTransferEmailPayload.model_validate(payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() @@ -221,7 +258,7 @@ class SendOwnerTransferEmailApi(Resource): if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() - if args["language"] is not None and args["language"] == "zh-Hans": + if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" @@ -238,22 +275,16 @@ class SendOwnerTransferEmailApi(Resource): return {"result": "success", "data": token} -parser_owner = ( - reqparse.RequestParser() - .add_argument("code", type=str, required=True, location="json") - .add_argument("token", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/members/owner-transfer-check") class OwnerTransferCheckApi(Resource): - @console_ns.expect(parser_owner) + @console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__]) @setup_required @login_required @account_initialization_required @is_allow_transfer_owner def post(self): - args = parser_owner.parse_args() + payload = console_ns.payload or {} + args = OwnerTransferCheckPayload.model_validate(payload) # check if the current user is the owner of the workspace current_user, _ = current_account_with_tenant() if not current_user.current_tenant: @@ -267,41 +298,37 @@ class OwnerTransferCheckApi(Resource): if is_owner_transfer_error_rate_limit: raise OwnerTransferLimitError() - token_data = AccountService.get_owner_transfer_data(args["token"]) + token_data = AccountService.get_owner_transfer_data(args.token) if token_data is None: raise InvalidTokenError() if user_email != token_data.get("email"): raise InvalidEmailError() - if args["code"] != token_data.get("code"): + if args.code != token_data.get("code"): AccountService.add_owner_transfer_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token - AccountService.revoke_owner_transfer_token(args["token"]) + AccountService.revoke_owner_transfer_token(args.token) # Refresh token data by generating a new token - _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={}) + _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args.code, additional_data={}) AccountService.reset_owner_transfer_error_rate_limit(user_email) return {"is_valid": True, "email": token_data.get("email"), "token": new_token} -parser_owner_transfer = reqparse.RequestParser().add_argument( - "token", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/members//owner-transfer") class OwnerTransfer(Resource): - @console_ns.expect(parser_owner_transfer) + @console_ns.expect(console_ns.models[OwnerTransferPayload.__name__]) @setup_required @login_required @account_initialization_required @is_allow_transfer_owner def post(self, member_id): - args = parser_owner_transfer.parse_args() + payload = console_ns.payload or {} + args = OwnerTransferPayload.model_validate(payload) # check if the current user is the owner of the workspace current_user, _ = current_account_with_tenant() @@ -313,14 +340,14 @@ class OwnerTransfer(Resource): if current_user.id == str(member_id): raise CannotTransferOwnerToSelfError() - transfer_token_data = AccountService.get_owner_transfer_data(args["token"]) + transfer_token_data = AccountService.get_owner_transfer_data(args.token) if not transfer_token_data: raise InvalidTokenError() if transfer_token_data.get("email") != current_user.email: raise InvalidEmailError() - AccountService.revoke_owner_transfer_token(args["token"]) + AccountService.revoke_owner_transfer_token(args.token) member = db.session.get(Account, str(member_id)) if not member: diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 8ca69121bf..d40748d5e3 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,31 +1,123 @@ import io +from typing import Any, Literal -from flask import send_file -from flask_restx import Resource, reqparse +from flask import request, send_file +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder -from libs.helper import StrLen, uuid_value +from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService -parser_model = reqparse.RequestParser().add_argument( - "model_type", - type=str, - required=False, - nullable=True, - choices=[mt.value for mt in ModelType], - location="args", +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class ParserModelList(BaseModel): + model_type: ModelType | None = None + + +class ParserCredentialId(BaseModel): + credential_id: str | None = None + + @field_validator("credential_id") + @classmethod + def validate_optional_credential_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ParserCredentialCreate(BaseModel): + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + + +class ParserCredentialUpdate(BaseModel): + credential_id: str + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + + @field_validator("credential_id") + @classmethod + def validate_update_credential_id(cls, value: str) -> str: + return uuid_value(value) + + +class ParserCredentialDelete(BaseModel): + credential_id: str + + @field_validator("credential_id") + @classmethod + def validate_delete_credential_id(cls, value: str) -> str: + return uuid_value(value) + + +class ParserCredentialSwitch(BaseModel): + credential_id: str + + @field_validator("credential_id") + @classmethod + def validate_switch_credential_id(cls, value: str) -> str: + return uuid_value(value) + + +class ParserCredentialValidate(BaseModel): + credentials: dict[str, Any] + + +class ParserPreferredProviderType(BaseModel): + preferred_provider_type: Literal["system", "custom"] + + +console_ns.schema_model( + ParserModelList.__name__, ParserModelList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserCredentialId.__name__, + ParserCredentialId.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserCredentialCreate.__name__, + ParserCredentialCreate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserCredentialUpdate.__name__, + ParserCredentialUpdate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserCredentialDelete.__name__, + ParserCredentialDelete.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserCredentialSwitch.__name__, + ParserCredentialSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserCredentialValidate.__name__, + ParserCredentialValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserPreferredProviderType.__name__, + ParserPreferredProviderType.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) @console_ns.route("/workspaces/current/model-providers") class ModelProviderListApi(Resource): - @console_ns.expect(parser_model) + @console_ns.expect(console_ns.models[ParserModelList.__name__]) @setup_required @login_required @account_initialization_required @@ -33,38 +125,18 @@ class ModelProviderListApi(Resource): _, current_tenant_id = current_account_with_tenant() tenant_id = current_tenant_id - args = parser_model.parse_args() + payload = request.args.to_dict(flat=True) # type: ignore + args = ParserModelList.model_validate(payload) model_provider_service = ModelProviderService() - provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) + provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.model_type) return jsonable_encoder({"data": provider_list}) -parser_cred = reqparse.RequestParser().add_argument( - "credential_id", type=uuid_value, required=False, nullable=True, location="args" -) -parser_post_cred = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) - -parser_put_cred = ( - reqparse.RequestParser() - .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) - -parser_delete_cred = reqparse.RequestParser().add_argument( - "credential_id", type=uuid_value, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/model-providers//credentials") class ModelProviderCredentialApi(Resource): - @console_ns.expect(parser_cred) + @console_ns.expect(console_ns.models[ParserCredentialId.__name__]) @setup_required @login_required @account_initialization_required @@ -72,23 +144,25 @@ class ModelProviderCredentialApi(Resource): _, current_tenant_id = current_account_with_tenant() tenant_id = current_tenant_id # if credential_id is not provided, return current used credential - args = parser_cred.parse_args() + payload = request.args.to_dict(flat=True) # type: ignore + args = ParserCredentialId.model_validate(payload) model_provider_service = ModelProviderService() credentials = model_provider_service.get_provider_credential( - tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id") + tenant_id=tenant_id, provider=provider, credential_id=args.credential_id ) return {"credentials": credentials} - @console_ns.expect(parser_post_cred) + @console_ns.expect(console_ns.models[ParserCredentialCreate.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): _, current_tenant_id = current_account_with_tenant() - args = parser_post_cred.parse_args() + payload = console_ns.payload or {} + args = ParserCredentialCreate.model_validate(payload) model_provider_service = ModelProviderService() @@ -96,15 +170,15 @@ class ModelProviderCredentialApi(Resource): model_provider_service.create_provider_credential( tenant_id=current_tenant_id, provider=provider, - credentials=args["credentials"], - credential_name=args["name"], + credentials=args.credentials, + credential_name=args.name, ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) return {"result": "success"}, 201 - @console_ns.expect(parser_put_cred) + @console_ns.expect(console_ns.models[ParserCredentialUpdate.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -112,7 +186,8 @@ class ModelProviderCredentialApi(Resource): def put(self, provider: str): _, current_tenant_id = current_account_with_tenant() - args = parser_put_cred.parse_args() + payload = console_ns.payload or {} + args = ParserCredentialUpdate.model_validate(payload) model_provider_service = ModelProviderService() @@ -120,71 +195,64 @@ class ModelProviderCredentialApi(Resource): model_provider_service.update_provider_credential( tenant_id=current_tenant_id, provider=provider, - credentials=args["credentials"], - credential_id=args["credential_id"], - credential_name=args["name"], + credentials=args.credentials, + credential_id=args.credential_id, + credential_name=args.name, ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) return {"result": "success"} - @console_ns.expect(parser_delete_cred) + @console_ns.expect(console_ns.models[ParserCredentialDelete.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def delete(self, provider: str): _, current_tenant_id = current_account_with_tenant() - args = parser_delete_cred.parse_args() + payload = console_ns.payload or {} + args = ParserCredentialDelete.model_validate(payload) model_provider_service = ModelProviderService() model_provider_service.remove_provider_credential( - tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"] + tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id ) return {"result": "success"}, 204 -parser_switch = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/model-providers//credentials/switch") class ModelProviderCredentialSwitchApi(Resource): - @console_ns.expect(parser_switch) + @console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): _, current_tenant_id = current_account_with_tenant() - args = parser_switch.parse_args() + payload = console_ns.payload or {} + args = ParserCredentialSwitch.model_validate(payload) service = ModelProviderService() service.switch_active_provider_credential( tenant_id=current_tenant_id, provider=provider, - credential_id=args["credential_id"], + credential_id=args.credential_id, ) return {"result": "success"} -parser_validate = reqparse.RequestParser().add_argument( - "credentials", type=dict, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/model-providers//credentials/validate") class ModelProviderValidateApi(Resource): - @console_ns.expect(parser_validate) + @console_ns.expect(console_ns.models[ParserCredentialValidate.__name__]) @setup_required @login_required @account_initialization_required def post(self, provider: str): _, current_tenant_id = current_account_with_tenant() - args = parser_validate.parse_args() + payload = console_ns.payload or {} + args = ParserCredentialValidate.model_validate(payload) tenant_id = current_tenant_id @@ -195,7 +263,7 @@ class ModelProviderValidateApi(Resource): try: model_provider_service.validate_provider_credentials( - tenant_id=tenant_id, provider=provider, credentials=args["credentials"] + tenant_id=tenant_id, provider=provider, credentials=args.credentials ) except CredentialsValidateFailedError as ex: result = False @@ -228,19 +296,9 @@ class ModelProviderIconApi(Resource): return send_file(io.BytesIO(icon), mimetype=mimetype) -parser_preferred = reqparse.RequestParser().add_argument( - "preferred_provider_type", - type=str, - required=True, - nullable=False, - choices=["system", "custom"], - location="json", -) - - @console_ns.route("/workspaces/current/model-providers//preferred-provider-type") class PreferredProviderTypeUpdateApi(Resource): - @console_ns.expect(parser_preferred) + @console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -250,11 +308,12 @@ class PreferredProviderTypeUpdateApi(Resource): tenant_id = current_tenant_id - args = parser_preferred.parse_args() + payload = console_ns.payload or {} + args = ParserPreferredProviderType.model_validate(payload) model_provider_service = ModelProviderService() model_provider_service.switch_preferred_provider( - tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"] + tenant_id=tenant_id, provider=provider, preferred_provider_type=args.preferred_provider_type ) return {"result": "success"} diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 2aca73806a..8e402b4bae 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,52 +1,172 @@ import logging +from typing import Any -from flask_restx import Resource, reqparse +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder -from libs.helper import StrLen, uuid_value +from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -parser_get_default = reqparse.RequestParser().add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="args", +class ParserGetDefault(BaseModel): + model_type: ModelType + + +class ParserPostDefault(BaseModel): + class Inner(BaseModel): + model_type: ModelType + model: str + provider: str | None = None + + model_settings: list[Inner] + + +console_ns.schema_model( + ParserGetDefault.__name__, ParserGetDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) -parser_post_default = reqparse.RequestParser().add_argument( - "model_settings", type=list, required=True, nullable=False, location="json" + +console_ns.schema_model( + ParserPostDefault.__name__, ParserPostDefault.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + + +class ParserDeleteModels(BaseModel): + model: str + model_type: ModelType + + +console_ns.schema_model( + ParserDeleteModels.__name__, ParserDeleteModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + + +class LoadBalancingPayload(BaseModel): + configs: list[dict[str, Any]] | None = None + enabled: bool | None = None + + +class ParserPostModels(BaseModel): + model: str + model_type: ModelType + load_balancing: LoadBalancingPayload | None = None + config_from: str | None = None + credential_id: str | None = None + + @field_validator("credential_id") + @classmethod + def validate_credential_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ParserGetCredentials(BaseModel): + model: str + model_type: ModelType + config_from: str | None = None + credential_id: str | None = None + + @field_validator("credential_id") + @classmethod + def validate_get_credential_id(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + +class ParserCredentialBase(BaseModel): + model: str + model_type: ModelType + + +class ParserCreateCredential(ParserCredentialBase): + name: str | None = Field(default=None, max_length=30) + credentials: dict[str, Any] + + +class ParserUpdateCredential(ParserCredentialBase): + credential_id: str + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + + @field_validator("credential_id") + @classmethod + def validate_update_credential_id(cls, value: str) -> str: + return uuid_value(value) + + +class ParserDeleteCredential(ParserCredentialBase): + credential_id: str + + @field_validator("credential_id") + @classmethod + def validate_delete_credential_id(cls, value: str) -> str: + return uuid_value(value) + + +class ParserParameter(BaseModel): + model: str + + +console_ns.schema_model( + ParserPostModels.__name__, ParserPostModels.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserGetCredentials.__name__, + ParserGetCredentials.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserCreateCredential.__name__, + ParserCreateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserUpdateCredential.__name__, + ParserUpdateCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserDeleteCredential.__name__, + ParserDeleteCredential.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserParameter.__name__, ParserParameter.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @console_ns.route("/workspaces/current/default-model") class DefaultModelApi(Resource): - @console_ns.expect(parser_get_default) + @console_ns.expect(console_ns.models[ParserGetDefault.__name__], validate=True) @setup_required @login_required @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - args = parser_get_default.parse_args() + args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( - tenant_id=tenant_id, model_type=args["model_type"] + tenant_id=tenant_id, model_type=args.model_type ) return jsonable_encoder({"data": default_model_entity}) - @console_ns.expect(parser_post_default) + @console_ns.expect(console_ns.models[ParserPostDefault.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -54,66 +174,31 @@ class DefaultModelApi(Resource): def post(self): _, tenant_id = current_account_with_tenant() - args = parser_post_default.parse_args() + args = ParserPostDefault.model_validate(console_ns.payload) model_provider_service = ModelProviderService() - model_settings = args["model_settings"] + model_settings = args.model_settings for model_setting in model_settings: - if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]: - raise ValueError("invalid model type") - - if "provider" not in model_setting: + if model_setting.provider is None: continue - if "model" not in model_setting: - raise ValueError("invalid model") - try: model_provider_service.update_default_model_of_model_type( tenant_id=tenant_id, - model_type=model_setting["model_type"], - provider=model_setting["provider"], - model=model_setting["model"], + model_type=model_setting.model_type, + provider=model_setting.provider, + model=model_setting.model, ) except Exception as ex: logger.exception( "Failed to update default model, model type: %s, model: %s", - model_setting["model_type"], - model_setting.get("model"), + model_setting.model_type, + model_setting.model, ) raise ex return {"result": "success"} -parser_post_models = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") - .add_argument("config_from", type=str, required=False, nullable=True, location="json") - .add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") -) -parser_delete_models = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) -) - - @console_ns.route("/workspaces/current/model-providers//models") class ModelProviderModelApi(Resource): @setup_required @@ -127,7 +212,7 @@ class ModelProviderModelApi(Resource): return jsonable_encoder({"data": models}) - @console_ns.expect(parser_post_models) + @console_ns.expect(console_ns.models[ParserPostModels.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -135,45 +220,45 @@ class ModelProviderModelApi(Resource): def post(self, provider: str): # To save the model's load balance configs _, tenant_id = current_account_with_tenant() - args = parser_post_models.parse_args() + args = ParserPostModels.model_validate(console_ns.payload) - if args.get("config_from", "") == "custom-model": - if not args.get("credential_id"): + if args.config_from == "custom-model": + if not args.credential_id: raise ValueError("credential_id is required when configuring a custom-model") service = ModelProviderService() service.switch_active_custom_model_credential( tenant_id=tenant_id, provider=provider, - model_type=args["model_type"], - model=args["model"], - credential_id=args["credential_id"], + model_type=args.model_type, + model=args.model, + credential_id=args.credential_id, ) model_load_balancing_service = ModelLoadBalancingService() - if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]: + if args.load_balancing and args.load_balancing.configs: # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - configs=args["load_balancing"]["configs"], - config_from=args.get("config_from", ""), + model=args.model, + model_type=args.model_type, + configs=args.load_balancing.configs, + config_from=args.config_from or "", ) - if args.get("load_balancing", {}).get("enabled"): + if args.load_balancing.enabled: model_load_balancing_service.enable_model_load_balancing( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type ) else: model_load_balancing_service.disable_model_load_balancing( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type ) return {"result": "success"}, 200 - @console_ns.expect(parser_delete_models) + @console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True) @setup_required @login_required @is_admin_or_owner_required @@ -181,113 +266,53 @@ class ModelProviderModelApi(Resource): def delete(self, provider: str): _, tenant_id = current_account_with_tenant() - args = parser_delete_models.parse_args() + args = ParserDeleteModels.model_validate(console_ns.payload) model_provider_service = ModelProviderService() model_provider_service.remove_model( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type ) return {"result": "success"}, 204 -parser_get_credentials = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="args") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="args", - ) - .add_argument("config_from", type=str, required=False, nullable=True, location="args") - .add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") -) - - -parser_post_cred = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") -) -parser_put_cred = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) -parser_delete_cred = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/model-providers//models/credentials") class ModelProviderModelCredentialApi(Resource): - @console_ns.expect(parser_get_credentials) + @console_ns.expect(console_ns.models[ParserGetCredentials.__name__]) @setup_required @login_required @account_initialization_required def get(self, provider: str): _, tenant_id = current_account_with_tenant() - args = parser_get_credentials.parse_args() + args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore model_provider_service = ModelProviderService() current_credential = model_provider_service.get_model_credential( tenant_id=tenant_id, provider=provider, - model_type=args["model_type"], - model=args["model"], - credential_id=args.get("credential_id"), + model_type=args.model_type, + model=args.model, + credential_id=args.credential_id, ) model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - config_from=args.get("config_from", ""), + model=args.model, + model_type=args.model_type, + config_from=args.config_from or "", ) - if args.get("config_from", "") == "predefined-model": + if args.config_from == "predefined-model": available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( tenant_id=tenant_id, provider_name=provider ) else: - model_type = ModelType.value_of(args["model_type"]).to_origin_model_type() + model_type = args.model_type available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"] + tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model ) return jsonable_encoder( @@ -304,7 +329,7 @@ class ModelProviderModelCredentialApi(Resource): } ) - @console_ns.expect(parser_post_cred) + @console_ns.expect(console_ns.models[ParserCreateCredential.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -312,7 +337,7 @@ class ModelProviderModelCredentialApi(Resource): def post(self, provider: str): _, tenant_id = current_account_with_tenant() - args = parser_post_cred.parse_args() + args = ParserCreateCredential.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -320,30 +345,30 @@ class ModelProviderModelCredentialApi(Resource): model_provider_service.create_model_credential( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], - credential_name=args["name"], + model=args.model, + model_type=args.model_type, + credentials=args.credentials, + credential_name=args.name, ) except CredentialsValidateFailedError as ex: logger.exception( "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", tenant_id, - args.get("model"), - args.get("model_type"), + args.model, + args.model_type, ) raise ValueError(str(ex)) return {"result": "success"}, 201 - @console_ns.expect(parser_put_cred) + @console_ns.expect(console_ns.models[ParserUpdateCredential.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def put(self, provider: str): _, current_tenant_id = current_account_with_tenant() - args = parser_put_cred.parse_args() + args = ParserUpdateCredential.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -351,106 +376,87 @@ class ModelProviderModelCredentialApi(Resource): model_provider_service.update_model_credential( tenant_id=current_tenant_id, provider=provider, - model_type=args["model_type"], - model=args["model"], - credentials=args["credentials"], - credential_id=args["credential_id"], - credential_name=args["name"], + model_type=args.model_type, + model=args.model, + credentials=args.credentials, + credential_id=args.credential_id, + credential_name=args.name, ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) return {"result": "success"} - @console_ns.expect(parser_delete_cred) + @console_ns.expect(console_ns.models[ParserDeleteCredential.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def delete(self, provider: str): _, current_tenant_id = current_account_with_tenant() - args = parser_delete_cred.parse_args() + args = ParserDeleteCredential.model_validate(console_ns.payload) model_provider_service = ModelProviderService() model_provider_service.remove_model_credential( tenant_id=current_tenant_id, provider=provider, - model_type=args["model_type"], - model=args["model"], - credential_id=args["credential_id"], + model_type=args.model_type, + model=args.model, + credential_id=args.credential_id, ) return {"result": "success"}, 204 -parser_switch = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") +class ParserSwitch(BaseModel): + model: str + model_type: ModelType + credential_id: str + + +console_ns.schema_model( + ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @console_ns.route("/workspaces/current/model-providers//models/credentials/switch") class ModelProviderModelCredentialSwitchApi(Resource): - @console_ns.expect(parser_switch) + @console_ns.expect(console_ns.models[ParserSwitch.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): _, current_tenant_id = current_account_with_tenant() - - args = parser_switch.parse_args() + args = ParserSwitch.model_validate(console_ns.payload) service = ModelProviderService() service.add_model_credential_to_model_list( tenant_id=current_tenant_id, provider=provider, - model_type=args["model_type"], - model=args["model"], - credential_id=args["credential_id"], + model_type=args.model_type, + model=args.model, + credential_id=args.credential_id, ) return {"result": "success"} -parser_model_enable_disable = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) -) - - @console_ns.route( "/workspaces/current/model-providers//models/enable", endpoint="model-provider-model-enable" ) class ModelProviderModelEnableApi(Resource): - @console_ns.expect(parser_model_enable_disable) + @console_ns.expect(console_ns.models[ParserDeleteModels.__name__]) @setup_required @login_required @account_initialization_required def patch(self, provider: str): _, tenant_id = current_account_with_tenant() - args = parser_model_enable_disable.parse_args() + args = ParserDeleteModels.model_validate(console_ns.payload) model_provider_service = ModelProviderService() model_provider_service.enable_model( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type ) return {"result": "success"} @@ -460,48 +466,43 @@ class ModelProviderModelEnableApi(Resource): "/workspaces/current/model-providers//models/disable", endpoint="model-provider-model-disable" ) class ModelProviderModelDisableApi(Resource): - @console_ns.expect(parser_model_enable_disable) + @console_ns.expect(console_ns.models[ParserDeleteModels.__name__]) @setup_required @login_required @account_initialization_required def patch(self, provider: str): _, tenant_id = current_account_with_tenant() - args = parser_model_enable_disable.parse_args() + args = ParserDeleteModels.model_validate(console_ns.payload) model_provider_service = ModelProviderService() model_provider_service.disable_model( - tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] + tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type ) return {"result": "success"} -parser_validate = ( - reqparse.RequestParser() - .add_argument("model", type=str, required=True, nullable=False, location="json") - .add_argument( - "model_type", - type=str, - required=True, - nullable=False, - choices=[mt.value for mt in ModelType], - location="json", - ) - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") +class ParserValidate(BaseModel): + model: str + model_type: ModelType + credentials: dict + + +console_ns.schema_model( + ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @console_ns.route("/workspaces/current/model-providers//models/credentials/validate") class ModelProviderModelValidateApi(Resource): - @console_ns.expect(parser_validate) + @console_ns.expect(console_ns.models[ParserValidate.__name__]) @setup_required @login_required @account_initialization_required def post(self, provider: str): _, tenant_id = current_account_with_tenant() - - args = parser_validate.parse_args() + args = ParserValidate.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -512,9 +513,9 @@ class ModelProviderModelValidateApi(Resource): model_provider_service.validate_model_credentials( tenant_id=tenant_id, provider=provider, - model=args["model"], - model_type=args["model_type"], - credentials=args["credentials"], + model=args.model, + model_type=args.model_type, + credentials=args.credentials, ) except CredentialsValidateFailedError as ex: result = False @@ -528,24 +529,19 @@ class ModelProviderModelValidateApi(Resource): return response -parser_parameter = reqparse.RequestParser().add_argument( - "model", type=str, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/model-providers//models/parameter-rules") class ModelProviderModelParameterRuleApi(Resource): - @console_ns.expect(parser_parameter) + @console_ns.expect(console_ns.models[ParserParameter.__name__]) @setup_required @login_required @account_initialization_required def get(self, provider: str): - args = parser_parameter.parse_args() + args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( - tenant_id=tenant_id, provider=provider, model=args["model"] + tenant_id=tenant_id, provider=provider, model=args.model ) return jsonable_encoder({"data": parameter_rules}) diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index e3345033f8..7e08ea55f9 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -1,7 +1,9 @@ import io +from typing import Literal from flask import request, send_file -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden from configs import dify_config @@ -17,6 +19,8 @@ from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_service import PluginService +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + @console_ns.route("/workspaces/current/plugin/debugging-key") class PluginDebuggingKeyApi(Resource): @@ -37,88 +41,251 @@ class PluginDebuggingKeyApi(Resource): raise ValueError(e) -parser_list = ( - reqparse.RequestParser() - .add_argument("page", type=int, required=False, location="args", default=1) - .add_argument("page_size", type=int, required=False, location="args", default=256) +class ParserList(BaseModel): + page: int = Field(default=1) + page_size: int = Field(default=256) + + +console_ns.schema_model( + ParserList.__name__, ParserList.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) ) @console_ns.route("/workspaces/current/plugin/list") class PluginListApi(Resource): - @console_ns.expect(parser_list) + @console_ns.expect(console_ns.models[ParserList.__name__]) @setup_required @login_required @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - args = parser_list.parse_args() + args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"]) + plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) except PluginDaemonClientSideError as e: raise ValueError(e) return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) -parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json") +class ParserLatest(BaseModel): + plugin_ids: list[str] + + +console_ns.schema_model( + ParserLatest.__name__, ParserLatest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + + +class ParserIcon(BaseModel): + tenant_id: str + filename: str + + +class ParserAsset(BaseModel): + plugin_unique_identifier: str + file_name: str + + +class ParserGithubUpload(BaseModel): + repo: str + version: str + package: str + + +class ParserPluginIdentifiers(BaseModel): + plugin_unique_identifiers: list[str] + + +class ParserGithubInstall(BaseModel): + plugin_unique_identifier: str + repo: str + version: str + package: str + + +class ParserPluginIdentifierQuery(BaseModel): + plugin_unique_identifier: str + + +class ParserTasks(BaseModel): + page: int + page_size: int + + +class ParserMarketplaceUpgrade(BaseModel): + original_plugin_unique_identifier: str + new_plugin_unique_identifier: str + + +class ParserGithubUpgrade(BaseModel): + original_plugin_unique_identifier: str + new_plugin_unique_identifier: str + repo: str + version: str + package: str + + +class ParserUninstall(BaseModel): + plugin_installation_id: str + + +class ParserPermissionChange(BaseModel): + install_permission: TenantPluginPermission.InstallPermission + debug_permission: TenantPluginPermission.DebugPermission + + +class ParserDynamicOptions(BaseModel): + plugin_id: str + provider: str + action: str + parameter: str + credential_id: str | None = None + provider_type: Literal["tool", "trigger"] + + +class PluginPermissionSettingsPayload(BaseModel): + install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE + debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE + + +class PluginAutoUpgradeSettingsPayload(BaseModel): + strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = ( + TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY + ) + upgrade_time_of_day: int = 0 + upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE + exclude_plugins: list[str] = Field(default_factory=list) + include_plugins: list[str] = Field(default_factory=list) + + +class ParserPreferencesChange(BaseModel): + permission: PluginPermissionSettingsPayload + auto_upgrade: PluginAutoUpgradeSettingsPayload + + +class ParserExcludePlugin(BaseModel): + plugin_id: str + + +class ParserReadme(BaseModel): + plugin_unique_identifier: str + language: str = Field(default="en-US") + + +console_ns.schema_model( + ParserIcon.__name__, ParserIcon.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserAsset.__name__, ParserAsset.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserGithubUpload.__name__, ParserGithubUpload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserPluginIdentifiers.__name__, + ParserPluginIdentifiers.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserGithubInstall.__name__, ParserGithubInstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserPluginIdentifierQuery.__name__, + ParserPluginIdentifierQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserTasks.__name__, ParserTasks.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserMarketplaceUpgrade.__name__, + ParserMarketplaceUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserGithubUpgrade.__name__, ParserGithubUpgrade.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserUninstall.__name__, ParserUninstall.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + ParserPermissionChange.__name__, + ParserPermissionChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserDynamicOptions.__name__, + ParserDynamicOptions.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserPreferencesChange.__name__, + ParserPreferencesChange.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserExcludePlugin.__name__, + ParserExcludePlugin.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + ParserReadme.__name__, ParserReadme.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) @console_ns.route("/workspaces/current/plugin/list/latest-versions") class PluginListLatestVersionsApi(Resource): - @console_ns.expect(parser_latest) + @console_ns.expect(console_ns.models[ParserLatest.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_latest.parse_args() + args = ParserLatest.model_validate(console_ns.payload) try: - versions = PluginService.list_latest_versions(args["plugin_ids"]) + versions = PluginService.list_latest_versions(args.plugin_ids) except PluginDaemonClientSideError as e: raise ValueError(e) return jsonable_encoder({"versions": versions}) -parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json") - - @console_ns.route("/workspaces/current/plugin/list/installations/ids") class PluginListInstallationsFromIdsApi(Resource): - @console_ns.expect(parser_ids) + @console_ns.expect(console_ns.models[ParserLatest.__name__]) @setup_required @login_required @account_initialization_required def post(self): _, tenant_id = current_account_with_tenant() - args = parser_ids.parse_args() + args = ParserLatest.model_validate(console_ns.payload) try: - plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"]) + plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids) except PluginDaemonClientSideError as e: raise ValueError(e) return jsonable_encoder({"plugins": plugins}) -parser_icon = ( - reqparse.RequestParser() - .add_argument("tenant_id", type=str, required=True, location="args") - .add_argument("filename", type=str, required=True, location="args") -) - - @console_ns.route("/workspaces/current/plugin/icon") class PluginIconApi(Resource): - @console_ns.expect(parser_icon) + @console_ns.expect(console_ns.models[ParserIcon.__name__]) @setup_required def get(self): - args = parser_icon.parse_args() + args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"]) + icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename) except PluginDaemonClientSideError as e: raise ValueError(e) @@ -128,20 +295,16 @@ class PluginIconApi(Resource): @console_ns.route("/workspaces/current/plugin/asset") class PluginAssetApi(Resource): + @console_ns.expect(console_ns.models[ParserAsset.__name__]) @setup_required @login_required @account_initialization_required def get(self): - req = ( - reqparse.RequestParser() - .add_argument("plugin_unique_identifier", type=str, required=True, location="args") - .add_argument("file_name", type=str, required=True, location="args") - ) - args = req.parse_args() + args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore _, tenant_id = current_account_with_tenant() try: - binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"]) + binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name) return send_file(io.BytesIO(binary), mimetype="application/octet-stream") except PluginDaemonClientSideError as e: raise ValueError(e) @@ -171,17 +334,9 @@ class PluginUploadFromPkgApi(Resource): return jsonable_encoder(response) -parser_github = ( - reqparse.RequestParser() - .add_argument("repo", type=str, required=True, location="json") - .add_argument("version", type=str, required=True, location="json") - .add_argument("package", type=str, required=True, location="json") -) - - @console_ns.route("/workspaces/current/plugin/upload/github") class PluginUploadFromGithubApi(Resource): - @console_ns.expect(parser_github) + @console_ns.expect(console_ns.models[ParserGithubUpload.__name__]) @setup_required @login_required @account_initialization_required @@ -189,10 +344,10 @@ class PluginUploadFromGithubApi(Resource): def post(self): _, tenant_id = current_account_with_tenant() - args = parser_github.parse_args() + args = ParserGithubUpload.model_validate(console_ns.payload) try: - response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"]) + response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package) except PluginDaemonClientSideError as e: raise ValueError(e) @@ -223,47 +378,28 @@ class PluginUploadFromBundleApi(Resource): return jsonable_encoder(response) -parser_pkg = reqparse.RequestParser().add_argument( - "plugin_unique_identifiers", type=list, required=True, location="json" -) - - @console_ns.route("/workspaces/current/plugin/install/pkg") class PluginInstallFromPkgApi(Resource): - @console_ns.expect(parser_pkg) + @console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__]) @setup_required @login_required @account_initialization_required @plugin_permission_required(install_required=True) def post(self): _, tenant_id = current_account_with_tenant() - args = parser_pkg.parse_args() - - # check if all plugin_unique_identifiers are valid string - for plugin_unique_identifier in args["plugin_unique_identifiers"]: - if not isinstance(plugin_unique_identifier, str): - raise ValueError("Invalid plugin unique identifier") + args = ParserPluginIdentifiers.model_validate(console_ns.payload) try: - response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"]) + response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers) except PluginDaemonClientSideError as e: raise ValueError(e) return jsonable_encoder(response) -parser_githubapi = ( - reqparse.RequestParser() - .add_argument("repo", type=str, required=True, location="json") - .add_argument("version", type=str, required=True, location="json") - .add_argument("package", type=str, required=True, location="json") - .add_argument("plugin_unique_identifier", type=str, required=True, location="json") -) - - @console_ns.route("/workspaces/current/plugin/install/github") class PluginInstallFromGithubApi(Resource): - @console_ns.expect(parser_githubapi) + @console_ns.expect(console_ns.models[ParserGithubInstall.__name__]) @setup_required @login_required @account_initialization_required @@ -271,15 +407,15 @@ class PluginInstallFromGithubApi(Resource): def post(self): _, tenant_id = current_account_with_tenant() - args = parser_githubapi.parse_args() + args = ParserGithubInstall.model_validate(console_ns.payload) try: response = PluginService.install_from_github( tenant_id, - args["plugin_unique_identifier"], - args["repo"], - args["version"], - args["package"], + args.plugin_unique_identifier, + args.repo, + args.version, + args.package, ) except PluginDaemonClientSideError as e: raise ValueError(e) @@ -287,14 +423,9 @@ class PluginInstallFromGithubApi(Resource): return jsonable_encoder(response) -parser_marketplace = reqparse.RequestParser().add_argument( - "plugin_unique_identifiers", type=list, required=True, location="json" -) - - @console_ns.route("/workspaces/current/plugin/install/marketplace") class PluginInstallFromMarketplaceApi(Resource): - @console_ns.expect(parser_marketplace) + @console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__]) @setup_required @login_required @account_initialization_required @@ -302,43 +433,33 @@ class PluginInstallFromMarketplaceApi(Resource): def post(self): _, tenant_id = current_account_with_tenant() - args = parser_marketplace.parse_args() - - # check if all plugin_unique_identifiers are valid string - for plugin_unique_identifier in args["plugin_unique_identifiers"]: - if not isinstance(plugin_unique_identifier, str): - raise ValueError("Invalid plugin unique identifier") + args = ParserPluginIdentifiers.model_validate(console_ns.payload) try: - response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"]) + response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers) except PluginDaemonClientSideError as e: raise ValueError(e) return jsonable_encoder(response) -parser_pkgapi = reqparse.RequestParser().add_argument( - "plugin_unique_identifier", type=str, required=True, location="args" -) - - @console_ns.route("/workspaces/current/plugin/marketplace/pkg") class PluginFetchMarketplacePkgApi(Resource): - @console_ns.expect(parser_pkgapi) + @console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__]) @setup_required @login_required @account_initialization_required @plugin_permission_required(install_required=True) def get(self): _, tenant_id = current_account_with_tenant() - args = parser_pkgapi.parse_args() + args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: return jsonable_encoder( { "manifest": PluginService.fetch_marketplace_pkg( tenant_id, - args["plugin_unique_identifier"], + args.plugin_unique_identifier, ) } ) @@ -346,14 +467,9 @@ class PluginFetchMarketplacePkgApi(Resource): raise ValueError(e) -parser_fetch = reqparse.RequestParser().add_argument( - "plugin_unique_identifier", type=str, required=True, location="args" -) - - @console_ns.route("/workspaces/current/plugin/fetch-manifest") class PluginFetchManifestApi(Resource): - @console_ns.expect(parser_fetch) + @console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__]) @setup_required @login_required @account_initialization_required @@ -361,30 +477,19 @@ class PluginFetchManifestApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = parser_fetch.parse_args() + args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore try: return jsonable_encoder( - { - "manifest": PluginService.fetch_plugin_manifest( - tenant_id, args["plugin_unique_identifier"] - ).model_dump() - } + {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()} ) except PluginDaemonClientSideError as e: raise ValueError(e) -parser_tasks = ( - reqparse.RequestParser() - .add_argument("page", type=int, required=True, location="args") - .add_argument("page_size", type=int, required=True, location="args") -) - - @console_ns.route("/workspaces/current/plugin/tasks") class PluginFetchInstallTasksApi(Resource): - @console_ns.expect(parser_tasks) + @console_ns.expect(console_ns.models[ParserTasks.__name__]) @setup_required @login_required @account_initialization_required @@ -392,12 +497,10 @@ class PluginFetchInstallTasksApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = parser_tasks.parse_args() + args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore try: - return jsonable_encoder( - {"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])} - ) + return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)}) except PluginDaemonClientSideError as e: raise ValueError(e) @@ -462,16 +565,9 @@ class PluginDeleteInstallTaskItemApi(Resource): raise ValueError(e) -parser_marketplace_api = ( - reqparse.RequestParser() - .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") - .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") -) - - @console_ns.route("/workspaces/current/plugin/upgrade/marketplace") class PluginUpgradeFromMarketplaceApi(Resource): - @console_ns.expect(parser_marketplace_api) + @console_ns.expect(console_ns.models[ParserMarketplaceUpgrade.__name__]) @setup_required @login_required @account_initialization_required @@ -479,31 +575,21 @@ class PluginUpgradeFromMarketplaceApi(Resource): def post(self): _, tenant_id = current_account_with_tenant() - args = parser_marketplace_api.parse_args() + args = ParserMarketplaceUpgrade.model_validate(console_ns.payload) try: return jsonable_encoder( PluginService.upgrade_plugin_with_marketplace( - tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"] + tenant_id, args.original_plugin_unique_identifier, args.new_plugin_unique_identifier ) ) except PluginDaemonClientSideError as e: raise ValueError(e) -parser_github_post = ( - reqparse.RequestParser() - .add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") - .add_argument("new_plugin_unique_identifier", type=str, required=True, location="json") - .add_argument("repo", type=str, required=True, location="json") - .add_argument("version", type=str, required=True, location="json") - .add_argument("package", type=str, required=True, location="json") -) - - @console_ns.route("/workspaces/current/plugin/upgrade/github") class PluginUpgradeFromGithubApi(Resource): - @console_ns.expect(parser_github_post) + @console_ns.expect(console_ns.models[ParserGithubUpgrade.__name__]) @setup_required @login_required @account_initialization_required @@ -511,56 +597,44 @@ class PluginUpgradeFromGithubApi(Resource): def post(self): _, tenant_id = current_account_with_tenant() - args = parser_github_post.parse_args() + args = ParserGithubUpgrade.model_validate(console_ns.payload) try: return jsonable_encoder( PluginService.upgrade_plugin_with_github( tenant_id, - args["original_plugin_unique_identifier"], - args["new_plugin_unique_identifier"], - args["repo"], - args["version"], - args["package"], + args.original_plugin_unique_identifier, + args.new_plugin_unique_identifier, + args.repo, + args.version, + args.package, ) ) except PluginDaemonClientSideError as e: raise ValueError(e) -parser_uninstall = reqparse.RequestParser().add_argument( - "plugin_installation_id", type=str, required=True, location="json" -) - - @console_ns.route("/workspaces/current/plugin/uninstall") class PluginUninstallApi(Resource): - @console_ns.expect(parser_uninstall) + @console_ns.expect(console_ns.models[ParserUninstall.__name__]) @setup_required @login_required @account_initialization_required @plugin_permission_required(install_required=True) def post(self): - args = parser_uninstall.parse_args() + args = ParserUninstall.model_validate(console_ns.payload) _, tenant_id = current_account_with_tenant() try: - return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} + return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)} except PluginDaemonClientSideError as e: raise ValueError(e) -parser_change_post = ( - reqparse.RequestParser() - .add_argument("install_permission", type=str, required=True, location="json") - .add_argument("debug_permission", type=str, required=True, location="json") -) - - @console_ns.route("/workspaces/current/plugin/permission/change") class PluginChangePermissionApi(Resource): - @console_ns.expect(parser_change_post) + @console_ns.expect(console_ns.models[ParserPermissionChange.__name__]) @setup_required @login_required @account_initialization_required @@ -570,14 +644,15 @@ class PluginChangePermissionApi(Resource): if not user.is_admin_or_owner: raise Forbidden() - args = parser_change_post.parse_args() - - install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) - debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) + args = ParserPermissionChange.model_validate(console_ns.payload) tenant_id = current_tenant_id - return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} + return { + "success": PluginPermissionService.change_permission( + tenant_id, args.install_permission, args.debug_permission + ) + } @console_ns.route("/workspaces/current/plugin/permission/fetch") @@ -605,20 +680,9 @@ class PluginFetchPermissionApi(Resource): ) -parser_dynamic = ( - reqparse.RequestParser() - .add_argument("plugin_id", type=str, required=True, location="args") - .add_argument("provider", type=str, required=True, location="args") - .add_argument("action", type=str, required=True, location="args") - .add_argument("parameter", type=str, required=True, location="args") - .add_argument("credential_id", type=str, required=False, location="args") - .add_argument("provider_type", type=str, required=True, location="args") -) - - @console_ns.route("/workspaces/current/plugin/parameters/dynamic-options") class PluginFetchDynamicSelectOptionsApi(Resource): - @console_ns.expect(parser_dynamic) + @console_ns.expect(console_ns.models[ParserDynamicOptions.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -627,18 +691,18 @@ class PluginFetchDynamicSelectOptionsApi(Resource): current_user, tenant_id = current_account_with_tenant() user_id = current_user.id - args = parser_dynamic.parse_args() + args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore try: options = PluginParameterService.get_dynamic_select_options( tenant_id=tenant_id, user_id=user_id, - plugin_id=args["plugin_id"], - provider=args["provider"], - action=args["action"], - parameter=args["parameter"], - credential_id=args["credential_id"], - provider_type=args["provider_type"], + plugin_id=args.plugin_id, + provider=args.provider, + action=args.action, + parameter=args.parameter, + credential_id=args.credential_id, + provider_type=args.provider_type, ) except PluginDaemonClientSideError as e: raise ValueError(e) @@ -646,16 +710,9 @@ class PluginFetchDynamicSelectOptionsApi(Resource): return jsonable_encoder({"options": options}) -parser_change = ( - reqparse.RequestParser() - .add_argument("permission", type=dict, required=True, location="json") - .add_argument("auto_upgrade", type=dict, required=True, location="json") -) - - @console_ns.route("/workspaces/current/plugin/preferences/change") class PluginChangePreferencesApi(Resource): - @console_ns.expect(parser_change) + @console_ns.expect(console_ns.models[ParserPreferencesChange.__name__]) @setup_required @login_required @account_initialization_required @@ -664,22 +721,20 @@ class PluginChangePreferencesApi(Resource): if not user.is_admin_or_owner: raise Forbidden() - args = parser_change.parse_args() + args = ParserPreferencesChange.model_validate(console_ns.payload) - permission = args["permission"] + permission = args.permission - install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone")) - debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone")) + install_permission = permission.install_permission + debug_permission = permission.debug_permission - auto_upgrade = args["auto_upgrade"] + auto_upgrade = args.auto_upgrade - strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting( - auto_upgrade.get("strategy_setting", "fix_only") - ) - upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0) - upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude")) - exclude_plugins = auto_upgrade.get("exclude_plugins", []) - include_plugins = auto_upgrade.get("include_plugins", []) + strategy_setting = auto_upgrade.strategy_setting + upgrade_time_of_day = auto_upgrade.upgrade_time_of_day + upgrade_mode = auto_upgrade.upgrade_mode + exclude_plugins = auto_upgrade.exclude_plugins + include_plugins = auto_upgrade.include_plugins # set permission set_permission_result = PluginPermissionService.change_permission( @@ -744,12 +799,9 @@ class PluginFetchPreferencesApi(Resource): return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict}) -parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json") - - @console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude") class PluginAutoUpgradeExcludePluginApi(Resource): - @console_ns.expect(parser_exclude) + @console_ns.expect(console_ns.models[ParserExcludePlugin.__name__]) @setup_required @login_required @account_initialization_required @@ -757,28 +809,20 @@ class PluginAutoUpgradeExcludePluginApi(Resource): # exclude one single plugin _, tenant_id = current_account_with_tenant() - args = parser_exclude.parse_args() + args = ParserExcludePlugin.model_validate(console_ns.payload) - return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) + return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)}) @console_ns.route("/workspaces/current/plugin/readme") class PluginReadmeApi(Resource): + @console_ns.expect(console_ns.models[ParserReadme.__name__]) @setup_required @login_required @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("plugin_unique_identifier", type=str, required=True, location="args") - .add_argument("language", type=str, required=False, location="args") - ) - args = parser.parse_args() + args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore return jsonable_encoder( - { - "readme": PluginService.fetch_plugin_readme( - tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US") - ) - } + {"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)} ) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 37c7dc3040..9b76cb7a9c 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,7 +1,8 @@ import logging from flask import request -from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, marshal, marshal_with +from pydantic import BaseModel, Field from sqlalchemy import select from werkzeug.exceptions import Unauthorized @@ -32,6 +33,45 @@ from services.file_service import FileService from services.workspace_service import WorkspaceService logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkspaceListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=20, ge=1, le=100) + + +class SwitchWorkspacePayload(BaseModel): + tenant_id: str + + +class WorkspaceCustomConfigPayload(BaseModel): + remove_webapp_brand: bool | None = None + replace_webapp_logo: str | None = None + + +class WorkspaceInfoPayload(BaseModel): + name: str + + +console_ns.schema_model( + WorkspaceListQuery.__name__, WorkspaceListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) +) + +console_ns.schema_model( + SwitchWorkspacePayload.__name__, + SwitchWorkspacePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + WorkspaceCustomConfigPayload.__name__, + WorkspaceCustomConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + WorkspaceInfoPayload.__name__, + WorkspaceInfoPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) provider_fields = { @@ -95,18 +135,15 @@ class TenantListApi(Resource): @console_ns.route("/all-workspaces") class WorkspaceListApi(Resource): + @console_ns.expect(console_ns.models[WorkspaceListQuery.__name__]) @setup_required @admin_required def get(self): - parser = ( - reqparse.RequestParser() - .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") - .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") - ) - args = parser.parse_args() + payload = request.args.to_dict(flat=True) # type: ignore + args = WorkspaceListQuery.model_validate(payload) stmt = select(Tenant).order_by(Tenant.created_at.desc()) - tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False) + tenants = db.paginate(select=stmt, page=args.page, per_page=args.limit, error_out=False) has_more = False if tenants.has_next: @@ -115,8 +152,8 @@ class WorkspaceListApi(Resource): return { "data": marshal(tenants.items, workspace_fields), "has_more": has_more, - "limit": args["limit"], - "page": args["page"], + "limit": args.limit, + "page": args.page, "total": tenants.total, }, 200 @@ -150,26 +187,24 @@ class TenantApi(Resource): return WorkspaceService.get_tenant_info(tenant), 200 -parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json") - - @console_ns.route("/workspaces/switch") class SwitchWorkspaceApi(Resource): - @console_ns.expect(parser_switch) + @console_ns.expect(console_ns.models[SwitchWorkspacePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): current_user, _ = current_account_with_tenant() - args = parser_switch.parse_args() + payload = console_ns.payload or {} + args = SwitchWorkspacePayload.model_validate(payload) # check if tenant_id is valid, 403 if not try: - TenantService.switch_tenant(current_user, args["tenant_id"]) + TenantService.switch_tenant(current_user, args.tenant_id) except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant if new_tenant is None: raise ValueError("Tenant not found") @@ -178,24 +213,21 @@ class SwitchWorkspaceApi(Resource): @console_ns.route("/workspaces/custom-config") class CustomConfigWorkspaceApi(Resource): + @console_ns.expect(console_ns.models[WorkspaceCustomConfigPayload.__name__]) @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): _, current_tenant_id = current_account_with_tenant() - parser = ( - reqparse.RequestParser() - .add_argument("remove_webapp_brand", type=bool, location="json") - .add_argument("replace_webapp_logo", type=str, location="json") - ) - args = parser.parse_args() + payload = console_ns.payload or {} + args = WorkspaceCustomConfigPayload.model_validate(payload) tenant = db.get_or_404(Tenant, current_tenant_id) custom_config_dict = { - "remove_webapp_brand": args["remove_webapp_brand"], - "replace_webapp_logo": args["replace_webapp_logo"] - if args["replace_webapp_logo"] is not None + "remove_webapp_brand": args.remove_webapp_brand, + "replace_webapp_logo": args.replace_webapp_logo + if args.replace_webapp_logo is not None else tenant.custom_config_dict.get("replace_webapp_logo"), } @@ -245,24 +277,22 @@ class WebappLogoWorkspaceApi(Resource): return {"id": upload_file.id}, 201 -parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json") - - @console_ns.route("/workspaces/info") class WorkspaceInfoApi(Resource): - @console_ns.expect(parser_info) + @console_ns.expect(console_ns.models[WorkspaceInfoPayload.__name__]) @setup_required @login_required @account_initialization_required # Change workspace name def post(self): _, current_tenant_id = current_account_with_tenant() - args = parser_info.parse_args() + payload = console_ns.payload or {} + args = WorkspaceInfoPayload.model_validate(payload) if not current_tenant_id: raise ValueError("No current tenant") tenant = db.get_or_404(Tenant, current_tenant_id) - tenant.name = args["name"] + tenant.name = args.name db.session.commit() return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} diff --git a/api/services/account_service.py b/api/services/account_service.py index 13c3993fb5..ac6d1bde77 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1352,7 +1352,7 @@ class RegisterService: @classmethod def invite_new_member( - cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account | None = None + cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None ) -> str: if not inviter: raise ValueError("Inviter is required")