From da684ebfaaa53b75fa490390f6ef6e713e1554c6 Mon Sep 17 00:00:00 2001 From: Joe <1264204425@qq.com> Date: Thu, 29 Aug 2024 15:15:48 +0800 Subject: [PATCH] feat: update register logic --- api/controllers/console/auth/error.py | 6 --- api/controllers/console/auth/login.py | 56 ++++++++++++++++----------- api/libs/helper.py | 28 ++++++++++---- api/services/account_service.py | 43 ++++++++++++++++---- 4 files changed, 89 insertions(+), 44 deletions(-) diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index a13275a7da..1c85035d25 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -35,9 +35,3 @@ class EmailLoginCodeError(BaseHTTPException): error_code = "email_login_code_error" description = "Email login code is invalid or expired." code = 400 - - -class NotAllowCreateWorkspaceError(BaseHTTPException): - error_code = "workspace_not_found" - description = "Workspace not found." - code = 400 diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 350276c7bf..6e439863b8 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -12,7 +12,6 @@ from controllers.console.auth.error import ( EmailLoginCodeError, InvalidEmailError, InvalidTokenError, - NotAllowCreateWorkspaceError, ) from controllers.console.setup import setup_required from libs.helper import email, get_remote_ip @@ -33,8 +32,6 @@ class LoginApi(Resource): parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") args = parser.parse_args() - # todo: Verify the recaptcha - try: account = AccountService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError as e: @@ -63,12 +60,31 @@ class LogoutApi(Resource): return {"result": "success"} +class ResetPasswordSendEmailApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + args = parser.parse_args() + + account = AccountService.get_user_through_email(args["email"]) + if account is None: + if dify_config.ALLOW_REGISTER: + token = AccountService.send_reset_password_email(email=args["email"]) + else: + raise InvalidEmailError() + else: + token = AccountService.send_reset_password_email(account=account) + + return {"result": "success", "data": token} + + class ResetPasswordApi(Resource): @setup_required def get(self): - # parser = reqparse.RequestParser() - # parser.add_argument('email', type=email, required=True, location='json') - # args = parser.parse_args() + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + args = parser.parse_args() # import mailchimp_transactional as MailchimpTransactional # from mailchimp_transactional.api_client import ApiClientError @@ -123,9 +139,13 @@ class EmailCodeLoginSendEmailApi(Resource): account = AccountService.get_user_through_email(args["email"]) if account is None: - raise InvalidEmailError() + if dify_config.ALLOW_REGISTER: + token = AccountService.send_email_code_login_email(email=args["email"]) + else: + raise InvalidEmailError() + else: + token = AccountService.send_email_code_login_email(account=account) - token = AccountService.send_email_code_login_email(account) return {"result": "success", "data": token} @@ -153,25 +173,17 @@ class EmailCodeLoginApi(Resource): AccountService.revoke_email_code_login_token(args["token"]) account = AccountService.get_user_through_email(user_email) if account is None: - # through environment variable, control whether to allow user to register and create workspace - if dify_config.ALLOW_REGISTER: - account = AccountService.create_account( - email=user_email, name=user_email, interface_language=languages[0] - ) - else: - raise InvalidEmailError() - if dify_config.ALLOW_CREATE_WORKSPACE: - TenantService.create_owner_tenant_if_not_exist(account=account) - else: - raise NotAllowCreateWorkspaceError() + account = AccountService.create_user_through_env( + email=user_email, name=user_email, interface_language=languages[0] + ) - else: - token = AccountService.login(account, ip_address=get_remote_ip(request)) + token = AccountService.login(account, ip_address=get_remote_ip(request)) - return {"result": "success", "data": token} + return {"result": "success", "data": token} api.add_resource(LoginApi, "/login") api.add_resource(LogoutApi, "/logout") api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") +api.add_resource(ResetPasswordApi, "/reset-password") diff --git a/api/libs/helper.py b/api/libs/helper.py index a486bfc872..b39e83dc03 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -189,15 +189,25 @@ def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Resp class TokenManager: @classmethod - def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str: - old_token = cls._get_current_token_for_account(account.id, token_type) - if old_token: - if isinstance(old_token, bytes): - old_token = old_token.decode("utf-8") - cls.revoke_token(old_token, token_type) + def generate_token( + cls, token_type: str, account: Optional[Account] = None, email: Optional[str] = None, + additional_data: dict = None + ) -> str: + if account is None and email is None: + raise ValueError("Account or email must be provided") + + account_id = account.id if account else None + account_email = account.email if account else email + + if account_id: + old_token = cls._get_current_token_for_account(account_id, token_type) + if old_token: + if isinstance(old_token, bytes): + old_token = old_token.decode("utf-8") + cls.revoke_token(old_token, token_type) token = str(uuid.uuid4()) - token_data = {"account_id": account.id, "email": account.email, "token_type": token_type} + token_data = {"account_id": account_id, "email": account_email, "token_type": token_type} if additional_data: token_data.update(additional_data) @@ -206,7 +216,9 @@ class TokenManager: expiry_time = int(expiry_hours * 60 * 60) redis_client.setex(token_key, expiry_time, json.dumps(token_data)) - cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) + if account_id: + cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) + return token @classmethod diff --git a/api/services/account_service.py b/api/services/account_service.py index 76e1cc64a8..e2f30d8ed3 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -158,6 +158,24 @@ class AccountService: db.session.commit() return account + @staticmethod + def create_user_through_env( + email: str, name: str, interface_language: str, password: Optional[str] = None + ) -> Account: + """create account""" + if dify_config.ALLOW_REGISTER: + account = AccountService.create_account( + email=email, name=name, interface_language=interface_language, password=password + ) + else: + raise Unauthorized("Register is not allowed.") + if dify_config.ALLOW_CREATE_WORKSPACE: + TenantService.create_owner_tenant_if_not_exist(account=account) + else: + raise Unauthorized("Create workspace is not allowed.") + + return account + @staticmethod def link_account_integrate(provider: str, open_id: str, account: Account) -> None: """Link account integrate""" @@ -231,13 +249,20 @@ class AccountService: return AccountService.load_user(account_id) @classmethod - def send_reset_password_email(cls, account): + def send_reset_password_email(cls, account: Optional[Account] = None, email: Optional[str] = None): if cls.reset_password_rate_limiter.is_rate_limited(account.email): raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.") - token = TokenManager.generate_token(account, "reset_password") - send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token) - cls.reset_password_rate_limiter.increment_rate_limit(account.email) + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data={"code": code} + ) + send_reset_password_mail_task.delay( + language=account.interface_language if account else languages[0], + to=account.email if account else email, + code=code, + ) + cls.reset_password_rate_limiter.increment_rate_limit(account.email if account else email) return token @classmethod @@ -249,12 +274,14 @@ class AccountService: return TokenManager.get_token_data(token, "reset_password") @classmethod - def send_email_code_login_email(cls, account: Account): + def send_email_code_login_email(cls, account: Optional[Account] = None, email: Optional[str] = None): code = "".join([str(random.randint(0, 9)) for _ in range(6)]) - token = TokenManager.generate_token(account, "email_code_login", {"code": code}) + token = TokenManager.generate_token( + account=account, email=email, token_type="email_code_login", additional_data={"code": code} + ) send_email_code_login_mail_task.delay( - language=account.interface_language, - to=account.email, + language=account.interface_language if account else languages[0], + to=account.email if account else email, code=code, )