diff --git a/api/commands/account.py b/api/commands/account.py index 0d99ce7a0fa..7f4f0a744f3 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -25,7 +25,7 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - account = AccountService.get_account_by_email_with_case_fallback(email.strip()) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip()) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) @@ -67,7 +67,7 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - account = AccountService.get_account_by_email_with_case_fallback(email.strip()) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip()) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 912eb26574c..ccbe9405fe5 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -15,6 +15,7 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) +from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import EmailStr, extract_remote_ip from libs.helper import timezone as validate_timezone_string @@ -100,7 +101,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - account = AccountService.get_account_by_email_with_case_fallback(args.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -175,7 +176,7 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: raise EmailAlreadyInUseError() diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d82f63c11db..061c29a13a2 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -82,7 +82,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - account = AccountService.get_account_by_email_with_case_fallback(args.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email) token = AccountService.send_reset_password_email( account=account, @@ -180,7 +180,7 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: account = db.session.merge(account) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 78d1583fde9..670d1c7818d 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -224,7 +224,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - account = AccountService.get_account_by_email_with_case_fallback(user_info.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, user_info.email) return account diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index b3230d77e69..4ea77e04b96 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -131,7 +131,7 @@ def _normalize_invitee_emails(emails: list[str]) -> list[str]: def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int: new_member_count = 0 for email in emails: - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: new_member_count += 1 continue diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index 0ff225271df..5653fbae432 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -193,7 +193,7 @@ class WorkspaceMembersApi(Resource): raise BadRequest(str(exc)) normalized_email = body.email.lower() - member = AccountService.get_account_by_email_with_case_fallback(normalized_email) + member = AccountService.get_account_by_email_with_case_fallback(db.session, normalized_email) if member is None: # invite_new_member just created or fetched this account. raise RuntimeError("invited member missing from DB after invite") diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index d0e023e40ee..ecc91113c32 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -69,7 +69,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - account = AccountService.get_account_by_email_with_case_fallback(request_email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, request_email) if account is None: raise AuthenticationFailedError() else: @@ -168,7 +168,7 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: account = db.session.merge(account) diff --git a/api/services/account_service.py b/api/services/account_service.py index 21b5f1eedba..80411dd288e 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -14,7 +14,6 @@ from werkzeug.exceptions import Unauthorized from configs import dify_config from constants.languages import get_valid_language, language_timezone_mapping -from core.db.session_factory import session_factory from events.tenant_event import tenant_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback @@ -981,19 +980,18 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str) -> Account | None: + def get_account_by_email_with_case_fallback(session: Session | scoped_session, email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. """ - with session_factory.create_session() as session: - account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() - if account or email == email.lower(): - return account + account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() + return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -1958,7 +1956,7 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) requires_setup = False if not account: diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 2b63d9171e9..6ecc8eb8bc9 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -35,7 +35,7 @@ class WebAppAuthService: @staticmethod def authenticate(email: str, password: str) -> Account: """authenticate account with email and password""" - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: raise AccountNotFoundError() @@ -55,7 +55,7 @@ class WebAppAuthService: @classmethod def get_user_through_email(cls, email: str): - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: return None diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index bb7921a5f45..109332e16c9 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -270,10 +270,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 014c1588fee..812aa299c1b 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -165,10 +165,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index d043c0d413a..d87afb87669 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -494,10 +494,7 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 2c6a9902401..d568a1c0b04 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -4,7 +4,7 @@ from __future__ import annotations import base64 from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -57,7 +57,7 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com") + mock_get_account.assert_called_once_with(ANY, "User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -177,7 +177,7 @@ class TestForgotPasswordResetApi: response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com") + mock_get_account.assert_called_once_with(ANY, "User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index e419428ca66..1600fcda50d 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -692,12 +692,7 @@ def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): second.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first, second] - mock_factory = MagicMock() - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - - with patch("services.account_service.session_factory", mock_factory): - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 3b5c6cc9bd6..c748fc0962e 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1821,7 +1821,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with("newuser@example.com") + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "newuser@example.com") def test_invite_new_member_normalizes_new_account_email( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies @@ -1865,7 +1865,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with(mixed_email) + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, mixed_email) mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with( mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "normal" @@ -1923,7 +1923,7 @@ class TestRegisterService: mock_tenant, mock_existing_account, "normal", requires_setup=True ) mock_task_dependencies.delay.assert_called_once() - mock_lookup.assert_called_once_with("existing@example.com") + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "existing@example.com") def test_invite_existing_active_account_requires_acceptance_before_joining( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies