oauth email lower

This commit is contained in:
hjlarry 2025-12-19 16:24:50 +08:00
parent 739dfd894f
commit 721a0ddb28
2 changed files with 92 additions and 5 deletions

View File

@ -118,7 +118,10 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token) invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation: if invitation:
invitation_email = invitation.get("email", None) invitation_email = invitation.get("email", None)
if invitation_email != user_info.email: invitation_email_normalized = (
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@ -172,7 +175,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account: if not account:
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() account = _fetch_account_by_email(session, user_info.email)
return account return account
@ -193,8 +196,9 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
tenant_was_created.send(new_tenant) tenant_was_created.send(new_tenant)
if not account: if not account:
normalized_email = user_info.email.lower()
if not FeatureService.get_system_features().is_allow_register: if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError( raise AccountRegisterError(
description=( description=(
"This email account has been deleted within the past " "This email account has been deleted within the past "
@ -205,7 +209,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
raise AccountRegisterError(description=("Invalid email or password")) raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify" account_name = user_info.name or "Dify"
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
) )
# Set interface language # Set interface language
@ -221,3 +229,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
AccountService.link_account_integrate(provider, user_info.id, account) AccountService.link_account_integrate(provider, user_info.id, account)
return account return account
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()

View File

@ -6,6 +6,7 @@ from flask import Flask
from controllers.console.auth.oauth import ( from controllers.console.auth.oauth import (
OAuthCallback, OAuthCallback,
OAuthLogin, OAuthLogin,
_fetch_account_by_email,
_generate_account, _generate_account,
_get_account_by_openid_or_email, _get_account_by_openid_or_email,
get_oauth_providers, get_oauth_providers,
@ -215,6 +216,34 @@ class TestOAuthCallback:
assert status_code == 400 assert status_code == 400
assert response["error"] == expected_error assert response["error"] == expected_error
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.redirect")
def test_invitation_comparison_is_case_insensitive(
self,
mock_redirect,
mock_register_service,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
id="123", name="Test User", email="User@Example.com"
)
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_register_service.is_valid_invite_token.return_value = True
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
resource.get("github")
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
@pytest.mark.parametrize( @pytest.mark.parametrize(
("account_status", "expected_redirect"), ("account_status", "expected_redirect"),
[ [
@ -411,7 +440,7 @@ class TestAccountGeneration:
assert result == mock_account assert result == mock_account
mock_account_model.get_by_openid.assert_called_once_with("github", "123") mock_account_model.get_by_openid.assert_called_once_with("github", "123")
# Test fallback to email # Test fallback to email lookup
mock_account_model.get_by_openid.return_value = None mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock() mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
@ -420,6 +449,20 @@ class TestAccountGeneration:
result = _get_account_by_openid_or_email("github", user_info) result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account assert result == mock_account
def test_fetch_account_by_email_fallback(self):
mock_session = MagicMock()
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_result = MagicMock()
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = _fetch_account_by_email(mock_session, "Case@Test.com")
assert result == expected_account
assert mock_session.execute.call_count == 2
@pytest.mark.parametrize( @pytest.mark.parametrize(
("allow_register", "existing_account", "should_create"), ("allow_register", "existing_account", "should_create"),
[ [
@ -465,6 +508,35 @@ class TestAccountGeneration:
mock_register_service.register.assert_called_once_with( mock_register_service.register.assert_called_once_with(
email="test@example.com", name="Test User", password=None, open_id="123", provider="github" email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
) )
else:
mock_register_service.register.assert_not_called()
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_register_with_lowercase_email(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
with app.test_request_context(headers={"Accept-Language": "en-US"}):
_generate_account("github", user_info)
mock_register_service.register.assert_called_once_with(
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email") @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.TenantService")