feat: update register logic

This commit is contained in:
Joe 2024-08-29 15:15:48 +08:00
parent ccc0ec8178
commit da684ebfaa
4 changed files with 89 additions and 44 deletions

View File

@ -35,9 +35,3 @@ class EmailLoginCodeError(BaseHTTPException):
error_code = "email_login_code_error"
description = "Email login code is invalid or expired."
code = 400
class NotAllowCreateWorkspaceError(BaseHTTPException):
error_code = "workspace_not_found"
description = "Workspace not found."
code = 400

View File

@ -12,7 +12,6 @@ from controllers.console.auth.error import (
EmailLoginCodeError,
InvalidEmailError,
InvalidTokenError,
NotAllowCreateWorkspaceError,
)
from controllers.console.setup import setup_required
from libs.helper import email, get_remote_ip
@ -33,8 +32,6 @@ class LoginApi(Resource):
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
args = parser.parse_args()
# todo: Verify the recaptcha
try:
account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError as e:
@ -63,12 +60,31 @@ class LogoutApi(Resource):
return {"result": "success"}
class ResetPasswordSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
args = parser.parse_args()
account = AccountService.get_user_through_email(args["email"])
if account is None:
if dify_config.ALLOW_REGISTER:
token = AccountService.send_reset_password_email(email=args["email"])
else:
raise InvalidEmailError()
else:
token = AccountService.send_reset_password_email(account=account)
return {"result": "success", "data": token}
class ResetPasswordApi(Resource):
@setup_required
def get(self):
# parser = reqparse.RequestParser()
# parser.add_argument('email', type=email, required=True, location='json')
# args = parser.parse_args()
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
args = parser.parse_args()
# import mailchimp_transactional as MailchimpTransactional
# from mailchimp_transactional.api_client import ApiClientError
@ -123,9 +139,13 @@ class EmailCodeLoginSendEmailApi(Resource):
account = AccountService.get_user_through_email(args["email"])
if account is None:
raise InvalidEmailError()
if dify_config.ALLOW_REGISTER:
token = AccountService.send_email_code_login_email(email=args["email"])
else:
raise InvalidEmailError()
else:
token = AccountService.send_email_code_login_email(account=account)
token = AccountService.send_email_code_login_email(account)
return {"result": "success", "data": token}
@ -153,25 +173,17 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"])
account = AccountService.get_user_through_email(user_email)
if account is None:
# through environment variable, control whether to allow user to register and create workspace
if dify_config.ALLOW_REGISTER:
account = AccountService.create_account(
email=user_email, name=user_email, interface_language=languages[0]
)
else:
raise InvalidEmailError()
if dify_config.ALLOW_CREATE_WORKSPACE:
TenantService.create_owner_tenant_if_not_exist(account=account)
else:
raise NotAllowCreateWorkspaceError()
account = AccountService.create_user_through_env(
email=user_email, name=user_email, interface_language=languages[0]
)
else:
token = AccountService.login(account, ip_address=get_remote_ip(request))
token = AccountService.login(account, ip_address=get_remote_ip(request))
return {"result": "success", "data": token}
return {"result": "success", "data": token}
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordApi, "/reset-password")

View File

@ -189,15 +189,25 @@ def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Resp
class TokenManager:
@classmethod
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
old_token = cls._get_current_token_for_account(account.id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
def generate_token(
cls, token_type: str, account: Optional[Account] = None, email: Optional[str] = None,
additional_data: dict = None
) -> str:
if account is None and email is None:
raise ValueError("Account or email must be provided")
account_id = account.id if account else None
account_email = account.email if account else email
if account_id:
old_token = cls._get_current_token_for_account(account_id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
token = str(uuid.uuid4())
token_data = {"account_id": account.id, "email": account.email, "token_type": token_type}
token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
if additional_data:
token_data.update(additional_data)
@ -206,7 +216,9 @@ class TokenManager:
expiry_time = int(expiry_hours * 60 * 60)
redis_client.setex(token_key, expiry_time, json.dumps(token_data))
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
if account_id:
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token
@classmethod

View File

@ -158,6 +158,24 @@ class AccountService:
db.session.commit()
return account
@staticmethod
def create_user_through_env(
email: str, name: str, interface_language: str, password: Optional[str] = None
) -> Account:
"""create account"""
if dify_config.ALLOW_REGISTER:
account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password
)
else:
raise Unauthorized("Register is not allowed.")
if dify_config.ALLOW_CREATE_WORKSPACE:
TenantService.create_owner_tenant_if_not_exist(account=account)
else:
raise Unauthorized("Create workspace is not allowed.")
return account
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
"""Link account integrate"""
@ -231,13 +249,20 @@ class AccountService:
return AccountService.load_user(account_id)
@classmethod
def send_reset_password_email(cls, account):
def send_reset_password_email(cls, account: Optional[Account] = None, email: Optional[str] = None):
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
token = TokenManager.generate_token(account, "reset_password")
send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token)
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data={"code": code}
)
send_reset_password_mail_task.delay(
language=account.interface_language if account else languages[0],
to=account.email if account else email,
code=code,
)
cls.reset_password_rate_limiter.increment_rate_limit(account.email if account else email)
return token
@classmethod
@ -249,12 +274,14 @@ class AccountService:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def send_email_code_login_email(cls, account: Account):
def send_email_code_login_email(cls, account: Optional[Account] = None, email: Optional[str] = None):
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(account, "email_code_login", {"code": code})
token = TokenManager.generate_token(
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
)
send_email_code_login_mail_task.delay(
language=account.interface_language,
to=account.email,
language=account.interface_language if account else languages[0],
to=account.email if account else email,
code=code,
)