From 721a0ddb280c8511f5a68f4dc2167b375ecf7bc2 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Fri, 19 Dec 2025 16:24:50 +0800 Subject: [PATCH] oauth email lower --- api/controllers/console/auth/oauth.py | 23 +++++- .../controllers/console/auth/test_oauth.py | 74 ++++++++++++++++++- 2 files changed, 92 insertions(+), 5 deletions(-) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 7ad1e56373..3c948f068d 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -118,7 +118,10 @@ class OAuthCallback(Resource): invitation = RegisterService.get_invitation_by_token(token=invite_token) if invitation: invitation_email = invitation.get("email", None) - if invitation_email != user_info.email: + invitation_email_normalized = ( + invitation_email.lower() if isinstance(invitation_email, str) else invitation_email + ) + if invitation_email_normalized != user_info.email.lower(): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") @@ -172,7 +175,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> if not account: with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() + account = _fetch_account_by_email(session, user_info.email) return account @@ -193,8 +196,9 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): tenant_was_created.send(new_tenant) if not account: + normalized_email = user_info.email.lower() if not FeatureService.get_system_features().is_allow_register: - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountRegisterError( description=( "This email account has been deleted within the past " @@ -205,7 +209,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): raise AccountRegisterError(description=("Invalid email or password")) account_name = user_info.name or "Dify" account = RegisterService.register( - email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider + email=normalized_email, + name=account_name, + password=None, + open_id=user_info.id, + provider=provider, ) # Set interface language @@ -221,3 +229,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): AccountService.link_account_integrate(provider, user_info.id, account) return account + + +def _fetch_account_by_email(session: Session, email: str) -> Account | None: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account + return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 399caf8c4d..8cd3e69c53 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -6,6 +6,7 @@ from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, OAuthLogin, + _fetch_account_by_email, _generate_account, _get_account_by_openid_or_email, get_oauth_providers, @@ -215,6 +216,34 @@ class TestOAuthCallback: assert status_code == 400 assert response["error"] == expected_error + @patch("controllers.console.auth.oauth.dify_config") + @patch("controllers.console.auth.oauth.get_oauth_providers") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.redirect") + def test_invitation_comparison_is_case_insensitive( + self, + mock_redirect, + mock_register_service, + mock_get_providers, + mock_config, + resource, + app, + oauth_setup, + ): + mock_config.CONSOLE_WEB_URL = "http://localhost:3000" + oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo( + id="123", name="Test User", email="User@Example.com" + ) + mock_get_providers.return_value = {"github": oauth_setup["provider"]} + mock_register_service.is_valid_invite_token.return_value = True + mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"} + + with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"): + resource.get("github") + + mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123") + mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123") + @pytest.mark.parametrize( ("account_status", "expected_redirect"), [ @@ -411,7 +440,7 @@ class TestAccountGeneration: assert result == mock_account mock_account_model.get_by_openid.assert_called_once_with("github", "123") - # Test fallback to email + # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None mock_session_instance = MagicMock() mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account @@ -420,6 +449,20 @@ class TestAccountGeneration: result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account + def test_fetch_account_by_email_fallback(self): + mock_session = MagicMock() + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] + + result = _fetch_account_by_email(mock_session, "Case@Test.com") + + assert result == expected_account + assert mock_session.execute.call_count == 2 + @pytest.mark.parametrize( ("allow_register", "existing_account", "should_create"), [ @@ -465,6 +508,35 @@ class TestAccountGeneration: mock_register_service.register.assert_called_once_with( email="test@example.com", name="Test User", password=None, open_id="123", provider="github" ) + else: + mock_register_service.register.assert_not_called() + + @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) + @patch("controllers.console.auth.oauth.FeatureService") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.AccountService") + @patch("controllers.console.auth.oauth.TenantService") + @patch("controllers.console.auth.oauth.db") + def test_should_register_with_lowercase_email( + self, + mock_db, + mock_tenant_service, + mock_account_service, + mock_register_service, + mock_feature_service, + mock_get_account, + app, + ): + user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com") + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_register_service.register.return_value = MagicMock() + + with app.test_request_context(headers={"Accept-Language": "en-US"}): + _generate_account("github", user_info) + + mock_register_service.register.assert_called_once_with( + email="upper@example.com", name="Test User", password=None, open_id="123", provider="github" + ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email") @patch("controllers.console.auth.oauth.TenantService")