refactor to get_account_by_email_with_case_fallback

This commit is contained in:
hjlarry 2025-12-20 10:38:21 +08:00
parent 0d1cfc1969
commit 6a4d6c7bf2
9 changed files with 58 additions and 87 deletions

View File

@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@ -75,7 +74,7 @@ class EmailRegisterSendEmailApi(Resource):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = _fetch_account_by_email(session, args.email)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@ -147,7 +146,7 @@ class EmailRegisterResetApi(Resource):
normalized_email = email.lower()
with Session(db.engine) as session:
account = _fetch_account_by_email(session, email)
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
@ -174,16 +173,3 @@ class EmailRegisterResetApi(Resource):
raise AccountInFreezeError()
return account
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
"""
Retrieve account by email with lowercase fallback for backward compatibility.
To prevent user register with Uppercase email success get a lowercase email account,
but already exist the Uppercase email account.
"""
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()

View File

@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@ -88,7 +86,7 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = _fetch_account_by_email(session, args.email)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
@ -192,7 +190,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = _fetch_account_by_email(session, email)
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)
@ -216,12 +214,3 @@ class ForgotPasswordResetApi(Resource):
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
"""Retrieve account by email with lowercase fallback for backward compatibility."""
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()

View File

@ -3,7 +3,6 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@ -175,7 +174,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
account = _fetch_account_by_email(session, user_info.email)
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account
@ -229,10 +228,3 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
AccountService.link_account_integrate(provider, user_info.id, account)
return account
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()

View File

@ -39,7 +39,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -551,7 +551,7 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email
else:
with Session(db.engine) as session:
account = _fetch_account_by_email(session, args.email)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
email_for_sending = account.email
@ -661,10 +661,3 @@ class CheckEmailUnique(Resource):
if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError()
return {"result": "success"}
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()

View File

@ -8,7 +8,7 @@ from hashlib import sha256
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@ -748,6 +748,21 @@ class AccountService:
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
@staticmethod
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
"""
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
This keeps backward compatibility for older records that stored uppercase emails while the
rest of the system gradually normalizes new inputs.
"""
query_session = session or db.session
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
@classmethod
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")

View File

@ -8,8 +8,8 @@ from controllers.console.auth.email_register import (
EmailRegisterCheckApi,
EmailRegisterResetApi,
EmailRegisterSendEmailApi,
_fetch_account_by_email,
)
from services.account_service import AccountService
@pytest.fixture
@ -21,6 +21,7 @@ def app():
class TestEmailRegisterSendEmailApi:
@patch("controllers.console.auth.email_register.Session")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
@patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False)
@ -31,6 +32,7 @@ class TestEmailRegisterSendEmailApi:
mock_is_email_send_ip_limit,
mock_is_freeze,
mock_send_mail,
mock_get_account,
mock_session_cls,
app,
):
@ -38,14 +40,9 @@ class TestEmailRegisterSendEmailApi:
mock_is_freeze.return_value = False
mock_account = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = mock_account
mock_session = MagicMock()
mock_session.execute.side_effect = [first_query, second_query]
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = mock_account
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch(
@ -65,7 +62,7 @@ class TestEmailRegisterSendEmailApi:
assert response == {"result": "success", "data": "token-123"}
mock_is_freeze.assert_called_once_with("invitee@example.com")
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
assert mock_session.execute.call_count == 2
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
mock_extract_ip.assert_called_once()
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
@ -119,6 +116,7 @@ class TestEmailRegisterResetApi:
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.Session")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
@ -127,6 +125,7 @@ class TestEmailRegisterResetApi:
mock_extract_ip,
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_session_cls,
mock_create_account,
mock_login,
@ -139,14 +138,9 @@ class TestEmailRegisterResetApi:
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
second_query = MagicMock()
second_query.scalar_one_or_none.return_value = None
mock_session = MagicMock()
mock_session.execute.side_effect = [first_query, second_query]
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch(
@ -166,10 +160,10 @@ class TestEmailRegisterResetApi:
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
assert mock_session.execute.call_count == 2
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
def test_fetch_account_by_email_fallback():
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
@ -178,7 +172,7 @@ def test_fetch_account_by_email_fallback():
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
account = _fetch_account_by_email(mock_session, "Case@Test.com")
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert account is expected_account
assert mock_session.execute.call_count == 2

View File

@ -8,8 +8,8 @@ from controllers.console.auth.forgot_password import (
ForgotPasswordCheckApi,
ForgotPasswordResetApi,
ForgotPasswordSendEmailApi,
_fetch_account_by_email,
)
from services.account_service import AccountService
@pytest.fixture
@ -21,7 +21,7 @@ def app():
class TestForgotPasswordSendEmailApi:
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password._fetch_account_by_email")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1")
@ -30,12 +30,12 @@ class TestForgotPasswordSendEmailApi:
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_fetch_account,
mock_get_account,
mock_session_cls,
app,
):
mock_account = MagicMock()
mock_fetch_account.return_value = mock_account
mock_get_account.return_value = mock_account
mock_send_email.return_value = "token-123"
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
@ -56,7 +56,7 @@ class TestForgotPasswordSendEmailApi:
response = ForgotPasswordSendEmailApi().post()
assert response == {"result": "success", "data": "token-123"}
mock_fetch_account.assert_called_once_with(mock_session, "User@Example.com")
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_send_email.assert_called_once_with(
account=mock_account,
email="user@example.com",
@ -114,21 +114,21 @@ class TestForgotPasswordCheckApi:
class TestForgotPasswordResetApi:
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
@patch("controllers.console.auth.forgot_password.Session")
@patch("controllers.console.auth.forgot_password._fetch_account_by_email")
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
def test_reset_fetches_account_with_original_email(
self,
mock_get_reset_data,
mock_revoke_token,
mock_fetch_account,
mock_get_account,
mock_session_cls,
mock_update_account,
app,
):
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
mock_account = MagicMock()
mock_fetch_account.return_value = mock_account
mock_get_account.return_value = mock_account
mock_session = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
@ -153,11 +153,11 @@ class TestForgotPasswordResetApi:
assert response == {"result": "success"}
mock_get_reset_data.assert_called_once_with("token-123")
mock_revoke_token.assert_called_once_with("token-123")
mock_fetch_account.assert_called_once_with(mock_session, "User@Example.com")
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
mock_update_account.assert_called_once()
def test_fetch_account_by_email_fallback():
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
mock_session = MagicMock()
first_query = MagicMock()
first_query.scalar_one_or_none.return_value = None
@ -166,7 +166,7 @@ def test_fetch_account_by_email_fallback():
second_query.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_query, second_query]
account = _fetch_account_by_email(mock_session, "Mixed@Test.com")
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
assert account is expected_account
assert mock_session.execute.call_count == 2

View File

@ -6,13 +6,13 @@ from flask import Flask
from controllers.console.auth.oauth import (
OAuthCallback,
OAuthLogin,
_fetch_account_by_email,
_generate_account,
_get_account_by_openid_or_email,
get_oauth_providers,
)
from libs.oauth import OAuthUserInfo
from models.account import AccountStatus
from services.account_service import AccountService
from services.errors.account import AccountRegisterError
@ -424,12 +424,12 @@ class TestAccountGeneration:
account.name = "Test User"
return account
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.oauth.Session")
@patch("controllers.console.auth.oauth.select")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.db")
def test_should_get_account_by_openid_or_email(
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
):
# Mock db.engine for Session creation
mock_db.engine = MagicMock()
@ -439,17 +439,19 @@ class TestAccountGeneration:
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
mock_get_account.assert_not_called()
# Test fallback to email lookup
mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
def test_fetch_account_by_email_fallback(self):
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
mock_session = MagicMock()
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
@ -458,7 +460,7 @@ class TestAccountGeneration:
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = _fetch_account_by_email(mock_session, "Case@Test.com")
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert result == expected_account
assert mock_session.execute.call_count == 2

View File

@ -10,9 +10,9 @@ from controllers.console.workspace.account import (
ChangeEmailResetApi,
ChangeEmailSendEmailApi,
CheckEmailUnique,
_fetch_account_by_email,
)
from models import Account
from services.account_service import AccountService
@pytest.fixture
@ -223,7 +223,7 @@ class TestCheckEmailUnique:
mock_check_unique.assert_called_once_with("case@test.com")
def test_fetch_account_by_email_fallback():
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
session = MagicMock()
first = MagicMock()
first.scalar_one_or_none.return_value = None
@ -232,7 +232,7 @@ def test_fetch_account_by_email_fallback():
second.scalar_one_or_none.return_value = expected_account
session.execute.side_effect = [first, second]
result = _fetch_account_by_email(session, "Mixed@Test.com")
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
assert result is expected_account
assert session.execute.call_count == 2