oauth email lower

This commit is contained in:
hjlarry 2025-12-19 16:24:50 +08:00
parent 739dfd894f
commit 721a0ddb28
2 changed files with 92 additions and 5 deletions

View File

@ -118,7 +118,10 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:
invitation_email_normalized = (
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@ -172,7 +175,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
account = _fetch_account_by_email(session, user_info.email)
return account
@ -193,8 +196,9 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
tenant_was_created.send(new_tenant)
if not account:
normalized_email = user_info.email.lower()
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@ -205,7 +209,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
)
# Set interface language
@ -221,3 +229,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
AccountService.link_account_integrate(provider, user_info.id, account)
return account
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account or email == email.lower():
return account
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()

View File

@ -6,6 +6,7 @@ from flask import Flask
from controllers.console.auth.oauth import (
OAuthCallback,
OAuthLogin,
_fetch_account_by_email,
_generate_account,
_get_account_by_openid_or_email,
get_oauth_providers,
@ -215,6 +216,34 @@ class TestOAuthCallback:
assert status_code == 400
assert response["error"] == expected_error
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.redirect")
def test_invitation_comparison_is_case_insensitive(
self,
mock_redirect,
mock_register_service,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
id="123", name="Test User", email="User@Example.com"
)
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_register_service.is_valid_invite_token.return_value = True
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
resource.get("github")
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
@pytest.mark.parametrize(
("account_status", "expected_redirect"),
[
@ -411,7 +440,7 @@ class TestAccountGeneration:
assert result == mock_account
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
# Test fallback to email
# Test fallback to email lookup
mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
@ -420,6 +449,20 @@ class TestAccountGeneration:
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
def test_fetch_account_by_email_fallback(self):
mock_session = MagicMock()
first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None
expected_account = MagicMock()
second_result = MagicMock()
second_result.scalar_one_or_none.return_value = expected_account
mock_session.execute.side_effect = [first_result, second_result]
result = _fetch_account_by_email(mock_session, "Case@Test.com")
assert result == expected_account
assert mock_session.execute.call_count == 2
@pytest.mark.parametrize(
("allow_register", "existing_account", "should_create"),
[
@ -465,6 +508,35 @@ class TestAccountGeneration:
mock_register_service.register.assert_called_once_with(
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
)
else:
mock_register_service.register.assert_not_called()
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_register_with_lowercase_email(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
with app.test_request_context(headers={"Accept-Language": "en-US"}):
_generate_account("github", user_info)
mock_register_service.register.assert_called_once_with(
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@patch("controllers.console.auth.oauth.TenantService")