mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
oauth email lower
This commit is contained in:
parent
739dfd894f
commit
721a0ddb28
@ -118,7 +118,10 @@ class OAuthCallback(Resource):
|
|||||||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||||
if invitation:
|
if invitation:
|
||||||
invitation_email = invitation.get("email", None)
|
invitation_email = invitation.get("email", None)
|
||||||
if invitation_email != user_info.email:
|
invitation_email_normalized = (
|
||||||
|
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
|
||||||
|
)
|
||||||
|
if invitation_email_normalized != user_info.email.lower():
|
||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
||||||
|
|
||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||||
@ -172,7 +175,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
|||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
account = _fetch_account_by_email(session, user_info.email)
|
||||||
|
|
||||||
return account
|
return account
|
||||||
|
|
||||||
@ -193,8 +196,9 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||||||
tenant_was_created.send(new_tenant)
|
tenant_was_created.send(new_tenant)
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
|
normalized_email = user_info.email.lower()
|
||||||
if not FeatureService.get_system_features().is_allow_register:
|
if not FeatureService.get_system_features().is_allow_register:
|
||||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||||
raise AccountRegisterError(
|
raise AccountRegisterError(
|
||||||
description=(
|
description=(
|
||||||
"This email account has been deleted within the past "
|
"This email account has been deleted within the past "
|
||||||
@ -205,7 +209,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||||||
raise AccountRegisterError(description=("Invalid email or password"))
|
raise AccountRegisterError(description=("Invalid email or password"))
|
||||||
account_name = user_info.name or "Dify"
|
account_name = user_info.name or "Dify"
|
||||||
account = RegisterService.register(
|
account = RegisterService.register(
|
||||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
email=normalized_email,
|
||||||
|
name=account_name,
|
||||||
|
password=None,
|
||||||
|
open_id=user_info.id,
|
||||||
|
provider=provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set interface language
|
# Set interface language
|
||||||
@ -221,3 +229,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||||
|
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_account_by_email(session: Session, email: str) -> Account | None:
|
||||||
|
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||||
|
if account or email == email.lower():
|
||||||
|
return account
|
||||||
|
return session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from flask import Flask
|
|||||||
from controllers.console.auth.oauth import (
|
from controllers.console.auth.oauth import (
|
||||||
OAuthCallback,
|
OAuthCallback,
|
||||||
OAuthLogin,
|
OAuthLogin,
|
||||||
|
_fetch_account_by_email,
|
||||||
_generate_account,
|
_generate_account,
|
||||||
_get_account_by_openid_or_email,
|
_get_account_by_openid_or_email,
|
||||||
get_oauth_providers,
|
get_oauth_providers,
|
||||||
@ -215,6 +216,34 @@ class TestOAuthCallback:
|
|||||||
assert status_code == 400
|
assert status_code == 400
|
||||||
assert response["error"] == expected_error
|
assert response["error"] == expected_error
|
||||||
|
|
||||||
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
|
@patch("controllers.console.auth.oauth.RegisterService")
|
||||||
|
@patch("controllers.console.auth.oauth.redirect")
|
||||||
|
def test_invitation_comparison_is_case_insensitive(
|
||||||
|
self,
|
||||||
|
mock_redirect,
|
||||||
|
mock_register_service,
|
||||||
|
mock_get_providers,
|
||||||
|
mock_config,
|
||||||
|
resource,
|
||||||
|
app,
|
||||||
|
oauth_setup,
|
||||||
|
):
|
||||||
|
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||||
|
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
|
||||||
|
id="123", name="Test User", email="User@Example.com"
|
||||||
|
)
|
||||||
|
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||||
|
mock_register_service.is_valid_invite_token.return_value = True
|
||||||
|
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
|
||||||
|
|
||||||
|
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
|
||||||
|
resource.get("github")
|
||||||
|
|
||||||
|
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
|
||||||
|
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("account_status", "expected_redirect"),
|
("account_status", "expected_redirect"),
|
||||||
[
|
[
|
||||||
@ -411,7 +440,7 @@ class TestAccountGeneration:
|
|||||||
assert result == mock_account
|
assert result == mock_account
|
||||||
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
||||||
|
|
||||||
# Test fallback to email
|
# Test fallback to email lookup
|
||||||
mock_account_model.get_by_openid.return_value = None
|
mock_account_model.get_by_openid.return_value = None
|
||||||
mock_session_instance = MagicMock()
|
mock_session_instance = MagicMock()
|
||||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||||
@ -420,6 +449,20 @@ class TestAccountGeneration:
|
|||||||
result = _get_account_by_openid_or_email("github", user_info)
|
result = _get_account_by_openid_or_email("github", user_info)
|
||||||
assert result == mock_account
|
assert result == mock_account
|
||||||
|
|
||||||
|
def test_fetch_account_by_email_fallback(self):
|
||||||
|
mock_session = MagicMock()
|
||||||
|
first_result = MagicMock()
|
||||||
|
first_result.scalar_one_or_none.return_value = None
|
||||||
|
expected_account = MagicMock()
|
||||||
|
second_result = MagicMock()
|
||||||
|
second_result.scalar_one_or_none.return_value = expected_account
|
||||||
|
mock_session.execute.side_effect = [first_result, second_result]
|
||||||
|
|
||||||
|
result = _fetch_account_by_email(mock_session, "Case@Test.com")
|
||||||
|
|
||||||
|
assert result == expected_account
|
||||||
|
assert mock_session.execute.call_count == 2
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("allow_register", "existing_account", "should_create"),
|
("allow_register", "existing_account", "should_create"),
|
||||||
[
|
[
|
||||||
@ -465,6 +508,35 @@ class TestAccountGeneration:
|
|||||||
mock_register_service.register.assert_called_once_with(
|
mock_register_service.register.assert_called_once_with(
|
||||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
mock_register_service.register.assert_not_called()
|
||||||
|
|
||||||
|
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||||
|
@patch("controllers.console.auth.oauth.FeatureService")
|
||||||
|
@patch("controllers.console.auth.oauth.RegisterService")
|
||||||
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
|
@patch("controllers.console.auth.oauth.db")
|
||||||
|
def test_should_register_with_lowercase_email(
|
||||||
|
self,
|
||||||
|
mock_db,
|
||||||
|
mock_tenant_service,
|
||||||
|
mock_account_service,
|
||||||
|
mock_register_service,
|
||||||
|
mock_feature_service,
|
||||||
|
mock_get_account,
|
||||||
|
app,
|
||||||
|
):
|
||||||
|
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
|
||||||
|
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||||
|
mock_register_service.register.return_value = MagicMock()
|
||||||
|
|
||||||
|
with app.test_request_context(headers={"Accept-Language": "en-US"}):
|
||||||
|
_generate_account("github", user_info)
|
||||||
|
|
||||||
|
mock_register_service.register.assert_called_once_with(
|
||||||
|
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||||
|
)
|
||||||
|
|
||||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||||
@patch("controllers.console.auth.oauth.TenantService")
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user