mirror of https://github.com/langgenius/dify.git
feat: update register logic
This commit is contained in:
parent
ccc0ec8178
commit
da684ebfaa
|
|
@ -35,9 +35,3 @@ class EmailLoginCodeError(BaseHTTPException):
|
|||
error_code = "email_login_code_error"
|
||||
description = "Email login code is invalid or expired."
|
||||
code = 400
|
||||
|
||||
|
||||
class NotAllowCreateWorkspaceError(BaseHTTPException):
|
||||
error_code = "workspace_not_found"
|
||||
description = "Workspace not found."
|
||||
code = 400
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from controllers.console.auth.error import (
|
|||
EmailLoginCodeError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
NotAllowCreateWorkspaceError,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from libs.helper import email, get_remote_ip
|
||||
|
|
@ -33,8 +32,6 @@ class LoginApi(Resource):
|
|||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: Verify the recaptcha
|
||||
|
||||
try:
|
||||
account = AccountService.authenticate(args["email"], args["password"])
|
||||
except services.errors.account.AccountLoginError as e:
|
||||
|
|
@ -63,12 +60,31 @@ class LogoutApi(Resource):
|
|||
return {"result": "success"}
|
||||
|
||||
|
||||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
if dify_config.ALLOW_REGISTER:
|
||||
token = AccountService.send_reset_password_email(email=args["email"])
|
||||
else:
|
||||
raise InvalidEmailError()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
class ResetPasswordApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
# parser = reqparse.RequestParser()
|
||||
# parser.add_argument('email', type=email, required=True, location='json')
|
||||
# args = parser.parse_args()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# import mailchimp_transactional as MailchimpTransactional
|
||||
# from mailchimp_transactional.api_client import ApiClientError
|
||||
|
|
@ -123,9 +139,13 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
raise InvalidEmailError()
|
||||
if dify_config.ALLOW_REGISTER:
|
||||
token = AccountService.send_email_code_login_email(email=args["email"])
|
||||
else:
|
||||
raise InvalidEmailError()
|
||||
else:
|
||||
token = AccountService.send_email_code_login_email(account=account)
|
||||
|
||||
token = AccountService.send_email_code_login_email(account)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
|
|
@ -153,25 +173,17 @@ class EmailCodeLoginApi(Resource):
|
|||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
if account is None:
|
||||
# through environment variable, control whether to allow user to register and create workspace
|
||||
if dify_config.ALLOW_REGISTER:
|
||||
account = AccountService.create_account(
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
)
|
||||
else:
|
||||
raise InvalidEmailError()
|
||||
if dify_config.ALLOW_CREATE_WORKSPACE:
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
else:
|
||||
raise NotAllowCreateWorkspaceError()
|
||||
account = AccountService.create_user_through_env(
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
)
|
||||
|
||||
else:
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
api.add_resource(ResetPasswordApi, "/reset-password")
|
||||
|
|
|
|||
|
|
@ -189,15 +189,25 @@ def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Resp
|
|||
|
||||
class TokenManager:
|
||||
@classmethod
|
||||
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
|
||||
old_token = cls._get_current_token_for_account(account.id, token_type)
|
||||
if old_token:
|
||||
if isinstance(old_token, bytes):
|
||||
old_token = old_token.decode("utf-8")
|
||||
cls.revoke_token(old_token, token_type)
|
||||
def generate_token(
|
||||
cls, token_type: str, account: Optional[Account] = None, email: Optional[str] = None,
|
||||
additional_data: dict = None
|
||||
) -> str:
|
||||
if account is None and email is None:
|
||||
raise ValueError("Account or email must be provided")
|
||||
|
||||
account_id = account.id if account else None
|
||||
account_email = account.email if account else email
|
||||
|
||||
if account_id:
|
||||
old_token = cls._get_current_token_for_account(account_id, token_type)
|
||||
if old_token:
|
||||
if isinstance(old_token, bytes):
|
||||
old_token = old_token.decode("utf-8")
|
||||
cls.revoke_token(old_token, token_type)
|
||||
|
||||
token = str(uuid.uuid4())
|
||||
token_data = {"account_id": account.id, "email": account.email, "token_type": token_type}
|
||||
token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
|
||||
if additional_data:
|
||||
token_data.update(additional_data)
|
||||
|
||||
|
|
@ -206,7 +216,9 @@ class TokenManager:
|
|||
expiry_time = int(expiry_hours * 60 * 60)
|
||||
redis_client.setex(token_key, expiry_time, json.dumps(token_data))
|
||||
|
||||
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
|
||||
if account_id:
|
||||
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
|
||||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -158,6 +158,24 @@ class AccountService:
|
|||
db.session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_user_through_env(
|
||||
email: str, name: str, interface_language: str, password: Optional[str] = None
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
if dify_config.ALLOW_REGISTER:
|
||||
account = AccountService.create_account(
|
||||
email=email, name=name, interface_language=interface_language, password=password
|
||||
)
|
||||
else:
|
||||
raise Unauthorized("Register is not allowed.")
|
||||
if dify_config.ALLOW_CREATE_WORKSPACE:
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
else:
|
||||
raise Unauthorized("Create workspace is not allowed.")
|
||||
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
|
||||
"""Link account integrate"""
|
||||
|
|
@ -231,13 +249,20 @@ class AccountService:
|
|||
return AccountService.load_user(account_id)
|
||||
|
||||
@classmethod
|
||||
def send_reset_password_email(cls, account):
|
||||
def send_reset_password_email(cls, account: Optional[Account] = None, email: Optional[str] = None):
|
||||
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
|
||||
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
|
||||
|
||||
token = TokenManager.generate_token(account, "reset_password")
|
||||
send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token)
|
||||
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
|
||||
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||
token = TokenManager.generate_token(
|
||||
account=account, email=email, token_type="reset_password", additional_data={"code": code}
|
||||
)
|
||||
send_reset_password_mail_task.delay(
|
||||
language=account.interface_language if account else languages[0],
|
||||
to=account.email if account else email,
|
||||
code=code,
|
||||
)
|
||||
cls.reset_password_rate_limiter.increment_rate_limit(account.email if account else email)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
|
|
@ -249,12 +274,14 @@ class AccountService:
|
|||
return TokenManager.get_token_data(token, "reset_password")
|
||||
|
||||
@classmethod
|
||||
def send_email_code_login_email(cls, account: Account):
|
||||
def send_email_code_login_email(cls, account: Optional[Account] = None, email: Optional[str] = None):
|
||||
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||
token = TokenManager.generate_token(account, "email_code_login", {"code": code})
|
||||
token = TokenManager.generate_token(
|
||||
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
|
||||
)
|
||||
send_email_code_login_mail_task.delay(
|
||||
language=account.interface_language,
|
||||
to=account.email,
|
||||
language=account.interface_language if account else languages[0],
|
||||
to=account.email if account else email,
|
||||
code=code,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue