diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index b9e391e049..958ae65802 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -3,7 +3,6 @@ import secrets from flask import request from flask_restx import Resource, reqparse -from sqlalchemy import select from sqlalchemy.orm import Session from controllers.console.auth.error import ( @@ -20,7 +19,6 @@ from controllers.web import web_ns from extensions.ext_database import db from libs.helper import email, extract_remote_ip from libs.password import hash_password, valid_password -from models import Account from services.account_service import AccountService @@ -47,6 +45,9 @@ class ForgotPasswordSendEmailApi(Resource): ) args = parser.parse_args() + request_email = args["email"] + normalized_email = request_email.lower() + ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() @@ -57,12 +58,12 @@ class ForgotPasswordSendEmailApi(Resource): language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) token = None if account is None: raise AuthenticationFailedError() else: - token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language) return {"result": "success", "data": token} @@ -86,9 +87,9 @@ class ForgotPasswordCheckApi(Resource): ) args = parser.parse_args() - user_email = args["email"] + user_email = args["email"].lower() - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() @@ -96,11 +97,13 @@ class ForgotPasswordCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if user_email != normalized_token_email: raise InvalidEmailError() if args["code"] != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(args["email"]) + AccountService.add_forgot_password_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -111,8 +114,8 @@ class ForgotPasswordCheckApi(Resource): user_email, code=args["code"], additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args["email"]) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_forgot_password_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @web_ns.route("/forgot-password/resets") @@ -161,7 +164,7 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: self._update_existing_account(account, password_hashed, salt, session) diff --git a/api/tests/unit_tests/controllers/web/test_forgot_password.py b/api/tests/unit_tests/controllers/web/test_forgot_password.py new file mode 100644 index 0000000000..68632b7094 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_forgot_password.py @@ -0,0 +1,146 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.forgot_password import ( + ForgotPasswordCheckApi, + ForgotPasswordResetApi, + ForgotPasswordSendEmailApi, +) + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture(autouse=True) +def _patch_wraps(): + wraps_features = SimpleNamespace(enable_email_password_login=True) + dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD") + with ( + patch("controllers.console.wraps.db") as mock_db, + patch("controllers.console.wraps.dify_config", dify_settings), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + yield + + +class TestForgotPasswordSendEmailApi: + @patch("controllers.web.forgot_password.AccountService.send_reset_password_email") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") + @patch("controllers.web.forgot_password.Session") + def test_should_normalize_email_before_sending( + self, + mock_session_cls, + mock_extract_ip, + mock_rate_limit, + mock_get_account, + mock_send_mail, + app, + ): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_send_mail.return_value = "token-123" + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") + mock_extract_ip.assert_called_once() + mock_rate_limit.assert_called_once_with("127.0.0.1") + + +class TestForgotPasswordCheckApi: + @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_should_normalize_email_for_validity_checks( + self, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"} + mock_generate_token.return_value = (None, "new-token") + + with app.test_request_context( + "/web/forgot-password/validity", + method="POST", + json={"email": "User@Example.com", "code": "1234", "token": "token-123"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_is_rate_limit.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + mock_generate_token.assert_called_once_with( + "user@example.com", + code="1234", + additional_data={"phase": "reset"}, + ) + mock_reset_rate.assert_called_once_with("user@example.com") + + +class TestForgotPasswordResetApi: + @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + def test_should_fetch_account_with_fallback( + self, + mock_get_reset_data, + mock_revoke_token, + mock_session_cls, + mock_get_account, + mock_update_account, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_update_account.assert_called_once() + mock_revoke_token.assert_called_once_with("token-123")