diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index fa082c735d..db1846f9a8 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -62,6 +62,7 @@ class EmailRegisterSendEmailApi(Resource): @email_register_enabled def post(self): args = EmailRegisterSendPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -70,13 +71,12 @@ class EmailRegisterSendEmailApi(Resource): if args.language in languages: language = args.language - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() - token = None - token = AccountService.send_email_register_email(email=args.email, account=account, language=language) + account = _fetch_account_by_email(session, args.email) + token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -88,9 +88,9 @@ class EmailRegisterCheckApi(Resource): def post(self): args = EmailRegisterValidityPayload.model_validate(console_ns.payload) - user_email = args.email + user_email = args.email.lower() - is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email) + is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email) if is_email_register_error_rate_limit: raise EmailRegisterLimitError() @@ -98,11 +98,14 @@ class EmailRegisterCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_email_register_error_rate_limit(args.email) + AccountService.add_email_register_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -113,8 +116,8 @@ class EmailRegisterCheckApi(Resource): user_email, code=args.code, additional_data={"phase": "register"} ) - AccountService.reset_email_register_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_email_register_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/email-register") @@ -141,22 +144,23 @@ class EmailRegisterResetApi(Resource): AccountService.revoke_email_register_token(args.token) email = register_data.get("email", "") + normalized_email = email.lower() with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = _fetch_account_by_email(session, email) if account: raise EmailAlreadyInUseError() else: - account = self._create_new_account(email, args.password_confirm) + account = self._create_new_account(normalized_email, args.password_confirm) if not account: raise AccountNotFoundError() token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(email) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} - def _create_new_account(self, email, password) -> Account | None: + def _create_new_account(self, email: str, password: str) -> Account | None: # Create new account if allowed account = None try: @@ -170,3 +174,16 @@ class EmailRegisterResetApi(Resource): raise AccountInFreezeError() return account + + +def _fetch_account_by_email(session: Session, email: str) -> Account | None: + """ + Retrieve account by email with lowercase fallback for backward compatibility. + To prevent user register with Uppercase email success get a lowercase email account, + but already exist the Uppercase email account. + """ + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account + + return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/unit_tests/controllers/console/auth/test_email_register.py new file mode 100644 index 0000000000..330df9595e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_email_register.py @@ -0,0 +1,184 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.email_register import ( + EmailRegisterCheckApi, + EmailRegisterResetApi, + EmailRegisterSendEmailApi, + _fetch_account_by_email, +) + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class TestEmailRegisterSendEmailApi: + @patch("controllers.console.auth.email_register.Session") + @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") + @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") + @patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_send_email_normalizes_and_falls_back( + self, + mock_extract_ip, + mock_is_email_send_ip_limit, + mock_is_freeze, + mock_send_mail, + mock_session_cls, + app, + ): + mock_send_mail.return_value = "token-123" + mock_is_freeze.return_value = False + mock_account = MagicMock() + + first_query = MagicMock() + first_query.scalar_one_or_none.return_value = None + second_query = MagicMock() + second_query.scalar_one_or_none.return_value = mock_account + + mock_session = MagicMock() + mock_session.execute.side_effect = [first_query, second_query] + mock_session_cls.return_value.__enter__.return_value = mock_session + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch( + "controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True) + ), patch( + "controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD") + ), patch( + "controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags + ): + with app.test_request_context( + "/email-register/send-email", + method="POST", + json={"email": "Invitee@Example.com", "language": "en-US"}, + ): + response = EmailRegisterSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_is_freeze.assert_called_once_with("invitee@example.com") + mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") + assert mock_session.execute.call_count == 2 + mock_extract_ip.assert_called_once() + mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") + + +class TestEmailRegisterCheckApi: + @patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.generate_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit") + def test_validity_normalizes_email_before_checks( + self, + mock_rate_limit_check, + mock_get_data, + mock_add_rate, + mock_revoke, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_rate_limit_check.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"} + mock_generate_token.return_value = (None, "new-token") + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch( + "controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD") + ), patch( + "controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags + ): + with app.test_request_context( + "/email-register/validity", + method="POST", + json={"email": "User@Example.com", "code": "4321", "token": "token-123"}, + ): + response = EmailRegisterCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_rate_limit_check.assert_called_once_with("user@example.com") + mock_generate_token.assert_called_once_with( + "user@example.com", code="4321", additional_data={"phase": "register"} + ) + mock_reset_rate.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke.assert_called_once_with("token-123") + + +class TestEmailRegisterResetApi: + @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.login") + @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") + @patch("controllers.console.auth.email_register.Session") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_reset_creates_account_with_normalized_email( + self, + mock_extract_ip, + mock_get_data, + mock_revoke_token, + mock_session_cls, + mock_create_account, + mock_login, + mock_reset_login_rate, + app, + ): + mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} + mock_create_account.return_value = MagicMock() + token_pair = MagicMock() + token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} + mock_login.return_value = token_pair + + first_query = MagicMock() + first_query.scalar_one_or_none.return_value = None + second_query = MagicMock() + second_query.scalar_one_or_none.return_value = None + + mock_session = MagicMock() + mock_session.execute.side_effect = [first_query, second_query] + mock_session_cls.return_value.__enter__.return_value = mock_session + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch( + "controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD") + ), patch( + "controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags + ): + with app.test_request_context( + "/email-register", + method="POST", + json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"}, + ): + response = EmailRegisterResetApi().post() + + assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} + mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!") + mock_reset_login_rate.assert_called_once_with("invitee@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_extract_ip.assert_called_once() + assert mock_session.execute.call_count == 2 + + +def test_fetch_account_by_email_fallback(): + mock_session = MagicMock() + first_query = MagicMock() + first_query.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_query = MagicMock() + second_query.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_query, second_query] + + account = _fetch_account_by_email(mock_session, "Case@Test.com") + + assert account is expected_account + assert mock_session.execute.call_count == 2