diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index db1846f9a8..c2a95ddad2 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,6 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -75,7 +74,7 @@ class EmailRegisterSendEmailApi(Resource): raise AccountInFreezeError() with Session(db.engine) as session: - account = _fetch_account_by_email(session, args.email) + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -147,7 +146,7 @@ class EmailRegisterResetApi(Resource): normalized_email = email.lower() with Session(db.engine) as session: - account = _fetch_account_by_email(session, email) + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: raise EmailAlreadyInUseError() @@ -174,16 +173,3 @@ class EmailRegisterResetApi(Resource): raise AccountInFreezeError() return account - - -def _fetch_account_by_email(session: Session, email: str) -> Account | None: - """ - Retrieve account by email with lowercase fallback for backward compatibility. - To prevent user register with Uppercase email success get a lowercase email account, - but already exist the Uppercase email account. - """ - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account - - return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 1875312a13..2675c5ed03 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from controllers.console import console_ns @@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password -from models import Account from services.account_service import AccountService, TenantService from services.feature_service import FeatureService @@ -88,7 +86,7 @@ class ForgotPasswordSendEmailApi(Resource): language = "en-US" with Session(db.engine) as session: - account = _fetch_account_by_email(session, args.email) + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_reset_password_email( account=account, @@ -192,7 +190,7 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") with Session(db.engine) as session: - account = _fetch_account_by_email(session, email) + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: self._update_existing_account(account, password_hashed, salt, session) @@ -216,12 +214,3 @@ class ForgotPasswordResetApi(Resource): TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) - - -def _fetch_account_by_email(session: Session, email: str) -> Account | None: - """Retrieve account by email with lowercase fallback for backward compatibility.""" - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account - - return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 3c948f068d..2959fc0cbd 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -3,7 +3,6 @@ import logging import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized @@ -175,7 +174,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> if not account: with Session(db.engine) as session: - account = _fetch_account_by_email(session, user_info.email) + account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) return account @@ -229,10 +228,3 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): AccountService.link_account_integrate(provider, user_info.id, account) return account - - -def _fetch_account_by_email(session: Session, email: str) -> Account | None: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account - return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index bb7d274f57..b0da1a806f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -39,7 +39,7 @@ from fields.member_fields import account_fields from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required -from models import Account, AccountIntegrate, InvitationCode +from models import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -551,7 +551,7 @@ class ChangeEmailSendEmailApi(Resource): user_email = current_user.email else: with Session(db.engine) as session: - account = _fetch_account_by_email(session, args.email) + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) if account is None: raise AccountNotFound() email_for_sending = account.email @@ -661,10 +661,3 @@ class CheckEmailUnique(Resource): if not AccountService.check_email_unique(normalized_email): raise EmailAlreadyInUseError() return {"result": "success"} - - -def _fetch_account_by_email(session: Session, email: str) -> Account | None: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account - return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() diff --git a/api/services/account_service.py b/api/services/account_service.py index 5a549dc318..8cfea4942e 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from hashlib import sha256 from typing import Any, cast from pydantic import BaseModel -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized @@ -748,6 +748,21 @@ class AccountService: cls.email_code_login_rate_limiter.increment_rate_limit(email) return token + @staticmethod + def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + """ + Retrieve an account by email and fall back to the lowercase email if the original lookup fails. + + This keeps backward compatibility for older records that stored uppercase emails while the + rest of the system gradually normalizes new inputs. + """ + query_session = session or db.session + account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account + + return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_code_login") diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/unit_tests/controllers/console/auth/test_email_register.py index 330df9595e..06b4017e67 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_register.py @@ -8,8 +8,8 @@ from controllers.console.auth.email_register import ( EmailRegisterCheckApi, EmailRegisterResetApi, EmailRegisterSendEmailApi, - _fetch_account_by_email, ) +from services.account_service import AccountService @pytest.fixture @@ -21,6 +21,7 @@ def app(): class TestEmailRegisterSendEmailApi: @patch("controllers.console.auth.email_register.Session") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") @patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False) @@ -31,6 +32,7 @@ class TestEmailRegisterSendEmailApi: mock_is_email_send_ip_limit, mock_is_freeze, mock_send_mail, + mock_get_account, mock_session_cls, app, ): @@ -38,14 +40,9 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze.return_value = False mock_account = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = mock_account - mock_session = MagicMock() - mock_session.execute.side_effect = [first_query, second_query] mock_session_cls.return_value.__enter__.return_value = mock_session + mock_get_account.return_value = mock_account feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch( @@ -65,7 +62,7 @@ class TestEmailRegisterSendEmailApi: assert response == {"result": "success", "data": "token-123"} mock_is_freeze.assert_called_once_with("invitee@example.com") mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") - assert mock_session.execute.call_count == 2 + mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) mock_extract_ip.assert_called_once() mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") @@ -119,6 +116,7 @@ class TestEmailRegisterResetApi: @patch("controllers.console.auth.email_register.AccountService.login") @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") @patch("controllers.console.auth.email_register.Session") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") @@ -127,6 +125,7 @@ class TestEmailRegisterResetApi: mock_extract_ip, mock_get_data, mock_revoke_token, + mock_get_account, mock_session_cls, mock_create_account, mock_login, @@ -139,14 +138,9 @@ class TestEmailRegisterResetApi: token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} mock_login.return_value = token_pair - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = None - mock_session = MagicMock() - mock_session.execute.side_effect = [first_query, second_query] mock_session_cls.return_value.__enter__.return_value = mock_session + mock_get_account.return_value = None feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch( @@ -166,10 +160,10 @@ class TestEmailRegisterResetApi: mock_reset_login_rate.assert_called_once_with("invitee@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_extract_ip.assert_called_once() - assert mock_session.execute.call_count == 2 + mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) -def test_fetch_account_by_email_fallback(): +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): mock_session = MagicMock() first_query = MagicMock() first_query.scalar_one_or_none.return_value = None @@ -178,7 +172,7 @@ def test_fetch_account_by_email_fallback(): second_query.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_query, second_query] - account = _fetch_account_by_email(mock_session, "Case@Test.com") + account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) assert account is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py index bf01165bbf..d512b8ad71 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py @@ -8,8 +8,8 @@ from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, ForgotPasswordResetApi, ForgotPasswordSendEmailApi, - _fetch_account_by_email, ) +from services.account_service import AccountService @pytest.fixture @@ -21,7 +21,7 @@ def app(): class TestForgotPasswordSendEmailApi: @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password._fetch_account_by_email") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1") @@ -30,12 +30,12 @@ class TestForgotPasswordSendEmailApi: mock_extract_ip, mock_is_ip_limit, mock_send_email, - mock_fetch_account, + mock_get_account, mock_session_cls, app, ): mock_account = MagicMock() - mock_fetch_account.return_value = mock_account + mock_get_account.return_value = mock_account mock_send_email.return_value = "token-123" mock_session = MagicMock() mock_session_cls.return_value.__enter__.return_value = mock_session @@ -56,7 +56,7 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_fetch_account.assert_called_once_with(mock_session, "User@Example.com") + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_send_email.assert_called_once_with( account=mock_account, email="user@example.com", @@ -114,21 +114,21 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password._fetch_account_by_email") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") def test_reset_fetches_account_with_original_email( self, mock_get_reset_data, mock_revoke_token, - mock_fetch_account, + mock_get_account, mock_session_cls, mock_update_account, app, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() - mock_fetch_account.return_value = mock_account + mock_get_account.return_value = mock_account mock_session = MagicMock() mock_session_cls.return_value.__enter__.return_value = mock_session @@ -153,11 +153,11 @@ class TestForgotPasswordResetApi: assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("token-123") mock_revoke_token.assert_called_once_with("token-123") - mock_fetch_account.assert_called_once_with(mock_session, "User@Example.com") + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_update_account.assert_called_once() -def test_fetch_account_by_email_fallback(): +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): mock_session = MagicMock() first_query = MagicMock() first_query.scalar_one_or_none.return_value = None @@ -166,7 +166,7 @@ def test_fetch_account_by_email_fallback(): second_query.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_query, second_query] - account = _fetch_account_by_email(mock_session, "Mixed@Test.com") + account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) assert account is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 8cd3e69c53..3ce79509bd 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -6,13 +6,13 @@ from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, OAuthLogin, - _fetch_account_by_email, _generate_account, _get_account_by_openid_or_email, get_oauth_providers, ) from libs.oauth import OAuthUserInfo from models.account import AccountStatus +from services.account_service import AccountService from services.errors.account import AccountRegisterError @@ -424,12 +424,12 @@ class TestAccountGeneration: account.name = "Test User" return account - @patch("controllers.console.auth.oauth.db") - @patch("controllers.console.auth.oauth.Account") + @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.oauth.Session") - @patch("controllers.console.auth.oauth.select") + @patch("controllers.console.auth.oauth.Account") + @patch("controllers.console.auth.oauth.db") def test_should_get_account_by_openid_or_email( - self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account + self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account ): # Mock db.engine for Session creation mock_db.engine = MagicMock() @@ -439,17 +439,19 @@ class TestAccountGeneration: result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account mock_account_model.get_by_openid.assert_called_once_with("github", "123") + mock_get_account.assert_not_called() # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account + mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) - def test_fetch_account_by_email_fallback(self): + def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): mock_session = MagicMock() first_result = MagicMock() first_result.scalar_one_or_none.return_value = None @@ -458,7 +460,7 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = _fetch_account_by_email(mock_session, "Case@Test.com") + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) assert result == expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 633fe0a10c..3cae3be52f 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -10,9 +10,9 @@ from controllers.console.workspace.account import ( ChangeEmailResetApi, ChangeEmailSendEmailApi, CheckEmailUnique, - _fetch_account_by_email, ) from models import Account +from services.account_service import AccountService @pytest.fixture @@ -223,7 +223,7 @@ class TestCheckEmailUnique: mock_check_unique.assert_called_once_with("case@test.com") -def test_fetch_account_by_email_fallback(): +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): session = MagicMock() first = MagicMock() first.scalar_one_or_none.return_value = None @@ -232,7 +232,7 @@ def test_fetch_account_by_email_fallback(): second.scalar_one_or_none.return_value = expected_account session.execute.side_effect = [first, second] - result = _fetch_account_by_email(session, "Mixed@Test.com") + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) assert result is expected_account assert session.execute.call_count == 2