From de60e5673570bf0c331ca085ca2cfff0e506e2d5 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Fri, 19 Dec 2025 14:04:40 +0800 Subject: [PATCH] forgot password email lower --- .../console/auth/forgot_password.py | 32 +++- .../console/auth/test_forgot_password.py | 172 ++++++++++++++++++ 2 files changed, 194 insertions(+), 10 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/auth/test_forgot_password.py diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 661f591182..1875312a13 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -76,6 +76,7 @@ class ForgotPasswordSendEmailApi(Resource): @email_password_login_enabled def post(self): args = ForgotPasswordSendPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -87,11 +88,11 @@ class ForgotPasswordSendEmailApi(Resource): language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() + account = _fetch_account_by_email(session, args.email) token = AccountService.send_reset_password_email( account=account, - email=args.email, + email=normalized_email, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, ) @@ -122,9 +123,9 @@ class ForgotPasswordCheckApi(Resource): def post(self): args = ForgotPasswordCheckPayload.model_validate(console_ns.payload) - user_email = args.email + user_email = args.email.lower() - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() @@ -132,11 +133,14 @@ class ForgotPasswordCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(args.email) + AccountService.add_forgot_password_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -147,8 +151,8 @@ class ForgotPasswordCheckApi(Resource): user_email, code=args.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_forgot_password_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/forgot-password/resets") @@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = _fetch_account_by_email(session, email) if account: self._update_existing_account(account, password_hashed, salt, session) @@ -213,3 +216,12 @@ class ForgotPasswordResetApi(Resource): TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) + + +def _fetch_account_by_email(session: Session, email: str) -> Account | None: + """Retrieve account by email with lowercase fallback for backward compatibility.""" + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account + + return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py new file mode 100644 index 0000000000..bf01165bbf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py @@ -0,0 +1,172 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.forgot_password import ( + ForgotPasswordCheckApi, + ForgotPasswordResetApi, + ForgotPasswordSendEmailApi, + _fetch_account_by_email, +) + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class TestForgotPasswordSendEmailApi: + @patch("controllers.console.auth.forgot_password.Session") + @patch("controllers.console.auth.forgot_password._fetch_account_by_email") + @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") + @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1") + def test_send_normalizes_email( + self, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_fetch_account, + mock_session_cls, + app, + ): + mock_account = MagicMock() + mock_fetch_account.return_value = mock_account + mock_send_email.return_value = "token-123" + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + controller_features = SimpleNamespace(is_allow_register=True) + with patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch( + "controllers.console.auth.forgot_password.FeatureService.get_system_features", + return_value=controller_features, + ), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch( + "controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features + ): + with app.test_request_context( + "/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_fetch_account.assert_called_once_with(mock_session, "User@Example.com") + mock_send_email.assert_called_once_with( + account=mock_account, + email="user@example.com", + language="zh-Hans", + is_allow_register=True, + ) + mock_is_ip_limit.assert_called_once_with("127.0.0.1") + mock_extract_ip.assert_called_once() + + +class TestForgotPasswordCheckApi: + @patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") + @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_check_normalizes_email( + self, + mock_rate_limit_check, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_rate_limit_check.return_value = False + mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"} + mock_generate_token.return_value = (None, "new-token") + + wraps_features = SimpleNamespace(enable_email_password_login=True) + with patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch( + "controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features + ): + with app.test_request_context( + "/forgot-password/validity", + method="POST", + json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"} + mock_rate_limit_check.assert_called_once_with("admin@example.com") + mock_generate_token.assert_called_once_with( + "admin@example.com", + code="4321", + additional_data={"phase": "reset"}, + ) + mock_reset_rate.assert_called_once_with("admin@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + + +class TestForgotPasswordResetApi: + @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") + @patch("controllers.console.auth.forgot_password.Session") + @patch("controllers.console.auth.forgot_password._fetch_account_by_email") + @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") + def test_reset_fetches_account_with_original_email( + self, + mock_get_reset_data, + mock_revoke_token, + mock_fetch_account, + mock_session_cls, + mock_update_account, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} + mock_account = MagicMock() + mock_fetch_account.return_value = mock_account + + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + wraps_features = SimpleNamespace(enable_email_password_login=True) + with patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch( + "controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD") + ), patch( + "controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features + ): + with app.test_request_context( + "/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_reset_data.assert_called_once_with("token-123") + mock_revoke_token.assert_called_once_with("token-123") + mock_fetch_account.assert_called_once_with(mock_session, "User@Example.com") + mock_update_account.assert_called_once() + + +def test_fetch_account_by_email_fallback(): + mock_session = MagicMock() + first_query = MagicMock() + first_query.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_query = MagicMock() + second_query.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_query, second_query] + + account = _fetch_account_by_email(mock_session, "Mixed@Test.com") + + assert account is expected_account + assert mock_session.execute.call_count == 2