mirror of https://github.com/langgenius/dify.git
oauth email lower
This commit is contained in:
parent
739dfd894f
commit
721a0ddb28
|
|
@ -118,7 +118,10 @@ class OAuthCallback(Resource):
|
|||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
if invitation:
|
||||
invitation_email = invitation.get("email", None)
|
||||
if invitation_email != user_info.email:
|
||||
invitation_email_normalized = (
|
||||
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
|
||||
)
|
||||
if invitation_email_normalized != user_info.email.lower():
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
|
@ -172,7 +175,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
|||
|
||||
if not account:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
||||
account = _fetch_account_by_email(session, user_info.email)
|
||||
|
||||
return account
|
||||
|
||||
|
|
@ -193,8 +196,9 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if not account:
|
||||
normalized_email = user_info.email.lower()
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountRegisterError(
|
||||
description=(
|
||||
"This email account has been deleted within the past "
|
||||
|
|
@ -205,7 +209,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
|
|
@ -221,3 +229,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from flask import Flask
|
|||
from controllers.console.auth.oauth import (
|
||||
OAuthCallback,
|
||||
OAuthLogin,
|
||||
_fetch_account_by_email,
|
||||
_generate_account,
|
||||
_get_account_by_openid_or_email,
|
||||
get_oauth_providers,
|
||||
|
|
@ -215,6 +216,34 @@ class TestOAuthCallback:
|
|||
assert status_code == 400
|
||||
assert response["error"] == expected_error
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_invitation_comparison_is_case_insensitive(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_register_service,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
|
||||
id="123", name="Test User", email="User@Example.com"
|
||||
)
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
mock_register_service.is_valid_invite_token.return_value = True
|
||||
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
|
||||
resource.get("github")
|
||||
|
||||
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
|
||||
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("account_status", "expected_redirect"),
|
||||
[
|
||||
|
|
@ -411,7 +440,7 @@ class TestAccountGeneration:
|
|||
assert result == mock_account
|
||||
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
||||
|
||||
# Test fallback to email
|
||||
# Test fallback to email lookup
|
||||
mock_account_model.get_by_openid.return_value = None
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
|
|
@ -420,6 +449,20 @@ class TestAccountGeneration:
|
|||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
|
||||
def test_fetch_account_by_email_fallback(self):
|
||||
mock_session = MagicMock()
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = _fetch_account_by_email(mock_session, "Case@Test.com")
|
||||
|
||||
assert result == expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_register", "existing_account", "should_create"),
|
||||
[
|
||||
|
|
@ -465,6 +508,35 @@ class TestAccountGeneration:
|
|||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
else:
|
||||
mock_register_service.register.assert_not_called()
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_register_with_lowercase_email(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
):
|
||||
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_register_service.register.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US"}):
|
||||
_generate_account("github", user_info)
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
|
|
|
|||
Loading…
Reference in New Issue