mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 15:07:53 +08:00
forgot password email lower
This commit is contained in:
parent
5fde1bd603
commit
de60e56735
@ -76,6 +76,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||||
|
normalized_email = args.email.lower()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
if AccountService.is_email_send_ip_limit(ip_address):
|
if AccountService.is_email_send_ip_limit(ip_address):
|
||||||
@ -87,11 +88,11 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||||||
language = "en-US"
|
language = "en-US"
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
account = _fetch_account_by_email(session, args.email)
|
||||||
|
|
||||||
token = AccountService.send_reset_password_email(
|
token = AccountService.send_reset_password_email(
|
||||||
account=account,
|
account=account,
|
||||||
email=args.email,
|
email=normalized_email,
|
||||||
language=language,
|
language=language,
|
||||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||||
)
|
)
|
||||||
@ -122,9 +123,9 @@ class ForgotPasswordCheckApi(Resource):
|
|||||||
def post(self):
|
def post(self):
|
||||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||||
|
|
||||||
user_email = args.email
|
user_email = args.email.lower()
|
||||||
|
|
||||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||||
if is_forgot_password_error_rate_limit:
|
if is_forgot_password_error_rate_limit:
|
||||||
raise EmailPasswordResetLimitError()
|
raise EmailPasswordResetLimitError()
|
||||||
|
|
||||||
@ -132,11 +133,14 @@ class ForgotPasswordCheckApi(Resource):
|
|||||||
if token_data is None:
|
if token_data is None:
|
||||||
raise InvalidTokenError()
|
raise InvalidTokenError()
|
||||||
|
|
||||||
if user_email != token_data.get("email"):
|
token_email = token_data.get("email")
|
||||||
|
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||||
|
|
||||||
|
if user_email != normalized_token_email:
|
||||||
raise InvalidEmailError()
|
raise InvalidEmailError()
|
||||||
|
|
||||||
if args.code != token_data.get("code"):
|
if args.code != token_data.get("code"):
|
||||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||||
raise EmailCodeError()
|
raise EmailCodeError()
|
||||||
|
|
||||||
# Verified, revoke the first token
|
# Verified, revoke the first token
|
||||||
@ -147,8 +151,8 @@ class ForgotPasswordCheckApi(Resource):
|
|||||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||||
)
|
)
|
||||||
|
|
||||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/forgot-password/resets")
|
@console_ns.route("/forgot-password/resets")
|
||||||
@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
password_hashed = hash_password(args.new_password, salt)
|
password_hashed = hash_password(args.new_password, salt)
|
||||||
|
|
||||||
email = reset_data.get("email", "")
|
email = reset_data.get("email", "")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
account = _fetch_account_by_email(session, email)
|
||||||
|
|
||||||
if account:
|
if account:
|
||||||
self._update_existing_account(account, password_hashed, salt, session)
|
self._update_existing_account(account, password_hashed, salt, session)
|
||||||
@ -213,3 +216,12 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||||
account.current_tenant = tenant
|
account.current_tenant = tenant
|
||||||
tenant_was_created.send(tenant)
|
tenant_was_created.send(tenant)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
|
||||||
|
"""Retrieve account by email with lowercase fallback for backward compatibility."""
|
||||||
|
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||||
|
if account or email == email.lower():
|
||||||
|
return account
|
||||||
|
|
||||||
|
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||||
|
|||||||
@ -0,0 +1,172 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from controllers.console.auth.forgot_password import (
|
||||||
|
ForgotPasswordCheckApi,
|
||||||
|
ForgotPasswordResetApi,
|
||||||
|
ForgotPasswordSendEmailApi,
|
||||||
|
_fetch_account_by_email,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app():
|
||||||
|
flask_app = Flask(__name__)
|
||||||
|
flask_app.config["TESTING"] = True
|
||||||
|
return flask_app
|
||||||
|
|
||||||
|
|
||||||
|
class TestForgotPasswordSendEmailApi:
|
||||||
|
@patch("controllers.console.auth.forgot_password.Session")
|
||||||
|
@patch("controllers.console.auth.forgot_password._fetch_account_by_email")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||||
|
@patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||||
|
def test_send_normalizes_email(
|
||||||
|
self,
|
||||||
|
mock_extract_ip,
|
||||||
|
mock_is_ip_limit,
|
||||||
|
mock_send_email,
|
||||||
|
mock_fetch_account,
|
||||||
|
mock_session_cls,
|
||||||
|
app,
|
||||||
|
):
|
||||||
|
mock_account = MagicMock()
|
||||||
|
mock_fetch_account.return_value = mock_account
|
||||||
|
mock_send_email.return_value = "token-123"
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||||
|
controller_features = SimpleNamespace(is_allow_register=True)
|
||||||
|
with patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch(
|
||||||
|
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
|
||||||
|
return_value=controller_features,
|
||||||
|
), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch(
|
||||||
|
"controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features
|
||||||
|
):
|
||||||
|
with app.test_request_context(
|
||||||
|
"/forgot-password",
|
||||||
|
method="POST",
|
||||||
|
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||||
|
):
|
||||||
|
response = ForgotPasswordSendEmailApi().post()
|
||||||
|
|
||||||
|
assert response == {"result": "success", "data": "token-123"}
|
||||||
|
mock_fetch_account.assert_called_once_with(mock_session, "User@Example.com")
|
||||||
|
mock_send_email.assert_called_once_with(
|
||||||
|
account=mock_account,
|
||||||
|
email="user@example.com",
|
||||||
|
language="zh-Hans",
|
||||||
|
is_allow_register=True,
|
||||||
|
)
|
||||||
|
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||||
|
mock_extract_ip.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestForgotPasswordCheckApi:
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||||
|
def test_check_normalizes_email(
|
||||||
|
self,
|
||||||
|
mock_rate_limit_check,
|
||||||
|
mock_get_data,
|
||||||
|
mock_add_rate,
|
||||||
|
mock_revoke_token,
|
||||||
|
mock_generate_token,
|
||||||
|
mock_reset_rate,
|
||||||
|
app,
|
||||||
|
):
|
||||||
|
mock_rate_limit_check.return_value = False
|
||||||
|
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
|
||||||
|
mock_generate_token.return_value = (None, "new-token")
|
||||||
|
|
||||||
|
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||||
|
with patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch(
|
||||||
|
"controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features
|
||||||
|
):
|
||||||
|
with app.test_request_context(
|
||||||
|
"/forgot-password/validity",
|
||||||
|
method="POST",
|
||||||
|
json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"},
|
||||||
|
):
|
||||||
|
response = ForgotPasswordCheckApi().post()
|
||||||
|
|
||||||
|
assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"}
|
||||||
|
mock_rate_limit_check.assert_called_once_with("admin@example.com")
|
||||||
|
mock_generate_token.assert_called_once_with(
|
||||||
|
"admin@example.com",
|
||||||
|
code="4321",
|
||||||
|
additional_data={"phase": "reset"},
|
||||||
|
)
|
||||||
|
mock_reset_rate.assert_called_once_with("admin@example.com")
|
||||||
|
mock_add_rate.assert_not_called()
|
||||||
|
mock_revoke_token.assert_called_once_with("token-123")
|
||||||
|
|
||||||
|
|
||||||
|
class TestForgotPasswordResetApi:
|
||||||
|
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||||
|
@patch("controllers.console.auth.forgot_password.Session")
|
||||||
|
@patch("controllers.console.auth.forgot_password._fetch_account_by_email")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||||
|
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||||
|
def test_reset_fetches_account_with_original_email(
|
||||||
|
self,
|
||||||
|
mock_get_reset_data,
|
||||||
|
mock_revoke_token,
|
||||||
|
mock_fetch_account,
|
||||||
|
mock_session_cls,
|
||||||
|
mock_update_account,
|
||||||
|
app,
|
||||||
|
):
|
||||||
|
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
||||||
|
mock_account = MagicMock()
|
||||||
|
mock_fetch_account.return_value = mock_account
|
||||||
|
|
||||||
|
mock_session = MagicMock()
|
||||||
|
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||||
|
|
||||||
|
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||||
|
with patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch(
|
||||||
|
"controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")
|
||||||
|
), patch(
|
||||||
|
"controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features
|
||||||
|
):
|
||||||
|
with app.test_request_context(
|
||||||
|
"/forgot-password/resets",
|
||||||
|
method="POST",
|
||||||
|
json={
|
||||||
|
"token": "token-123",
|
||||||
|
"new_password": "ValidPass123!",
|
||||||
|
"password_confirm": "ValidPass123!",
|
||||||
|
},
|
||||||
|
):
|
||||||
|
response = ForgotPasswordResetApi().post()
|
||||||
|
|
||||||
|
assert response == {"result": "success"}
|
||||||
|
mock_get_reset_data.assert_called_once_with("token-123")
|
||||||
|
mock_revoke_token.assert_called_once_with("token-123")
|
||||||
|
mock_fetch_account.assert_called_once_with(mock_session, "User@Example.com")
|
||||||
|
mock_update_account.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_account_by_email_fallback():
|
||||||
|
mock_session = MagicMock()
|
||||||
|
first_query = MagicMock()
|
||||||
|
first_query.scalar_one_or_none.return_value = None
|
||||||
|
expected_account = MagicMock()
|
||||||
|
second_query = MagicMock()
|
||||||
|
second_query.scalar_one_or_none.return_value = expected_account
|
||||||
|
mock_session.execute.side_effect = [first_query, second_query]
|
||||||
|
|
||||||
|
account = _fetch_account_by_email(mock_session, "Mixed@Test.com")
|
||||||
|
|
||||||
|
assert account is expected_account
|
||||||
|
assert mock_session.execute.call_count == 2
|
||||||
Loading…
Reference in New Issue
Block a user