mirror of https://github.com/langgenius/dify.git
email reg lower
This commit is contained in:
parent
8e0d5c8cd2
commit
5fde1bd603
|
|
@ -62,6 +62,7 @@ class EmailRegisterSendEmailApi(Resource):
|
|||
@email_register_enabled
|
||||
def post(self):
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
|
|
@ -70,13 +71,12 @@ class EmailRegisterSendEmailApi(Resource):
|
|||
if args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
account = _fetch_account_by_email(session, args.email)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
|
|
@ -88,9 +88,9 @@ class EmailRegisterCheckApi(Resource):
|
|||
def post(self):
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
|
|
@ -98,11 +98,14 @@ class EmailRegisterCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
AccountService.add_email_register_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -113,8 +116,8 @@ class EmailRegisterCheckApi(Resource):
|
|||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_email_register_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register")
|
||||
|
|
@ -141,22 +144,23 @@ class EmailRegisterResetApi(Resource):
|
|||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = _fetch_account_by_email(session, email)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email, password) -> Account | None:
|
||||
def _create_new_account(self, email: str, password: str) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
try:
|
||||
|
|
@ -170,3 +174,16 @@ class EmailRegisterResetApi(Resource):
|
|||
raise AccountInFreezeError()
|
||||
|
||||
return account
|
||||
|
||||
|
||||
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
|
||||
"""
|
||||
Retrieve account by email with lowercase fallback for backward compatibility.
|
||||
To prevent user register with Uppercase email success get a lowercase email account,
|
||||
but already exist the Uppercase email account.
|
||||
"""
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,184 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.email_register import (
|
||||
EmailRegisterCheckApi,
|
||||
EmailRegisterResetApi,
|
||||
EmailRegisterSendEmailApi,
|
||||
_fetch_account_by_email,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class TestEmailRegisterSendEmailApi:
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
|
||||
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
|
||||
@patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_send_email_normalizes_and_falls_back(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_email_send_ip_limit,
|
||||
mock_is_freeze,
|
||||
mock_send_mail,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_is_freeze.return_value = False
|
||||
mock_account = MagicMock()
|
||||
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = mock_account
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch(
|
||||
"controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)
|
||||
), patch(
|
||||
"controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")
|
||||
), patch(
|
||||
"controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register/send-email",
|
||||
method="POST",
|
||||
json={"email": "Invitee@Example.com", "language": "en-US"},
|
||||
):
|
||||
response = EmailRegisterSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_is_freeze.assert_called_once_with("invitee@example.com")
|
||||
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
|
||||
assert mock_session.execute.call_count == 2
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
||||
class TestEmailRegisterCheckApi:
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.generate_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit")
|
||||
def test_validity_normalizes_email_before_checks(
|
||||
self,
|
||||
mock_rate_limit_check,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch(
|
||||
"controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")
|
||||
), patch(
|
||||
"controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "4321", "token": "token-123"},
|
||||
):
|
||||
response = EmailRegisterCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_rate_limit_check.assert_called_once_with("user@example.com")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="4321", additional_data={"phase": "register"}
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
class TestEmailRegisterResetApi:
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_reset_creates_account_with_normalized_email(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = None
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch(
|
||||
"controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")
|
||||
), patch(
|
||||
"controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register",
|
||||
method="POST",
|
||||
json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"},
|
||||
):
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
|
||||
def test_fetch_account_by_email_fallback():
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
|
||||
account = _fetch_account_by_email(mock_session, "Case@Test.com")
|
||||
|
||||
assert account is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
Loading…
Reference in New Issue