mirror of https://github.com/langgenius/dify.git
add oauthuser flag for frontend when use oauth login
This commit is contained in:
parent
3505516e8e
commit
5ad435ef32
|
|
@ -124,7 +124,7 @@ class OAuthCallback(Resource):
|
|||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
||||
try:
|
||||
account = _generate_account(provider, user_info)
|
||||
account, oauth_new_user = _generate_account(provider, user_info)
|
||||
except AccountNotFoundError:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
|
||||
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
|
||||
|
|
@ -159,7 +159,7 @@ class OAuthCallback(Resource):
|
|||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}?oauth_new_user={str(oauth_new_user).lower()}")
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
|
|
@ -177,9 +177,10 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
|||
return account
|
||||
|
||||
|
||||
def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
|
||||
# Get account by openid or email.
|
||||
account = _get_account_by_openid_or_email(provider, user_info)
|
||||
oauth_new_user = False
|
||||
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
|
|
@ -193,6 +194,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if not account:
|
||||
oauth_new_user = True
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
||||
raise AccountRegisterError(
|
||||
|
|
@ -220,4 +222,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
|||
# Link account
|
||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
|
||||
return account
|
||||
return account, oauth_new_user
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class TestOAuthCallback:
|
|||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
mock_generate_account.return_value = oauth_setup["account"]
|
||||
mock_generate_account.return_value = (oauth_setup["account"], True)
|
||||
mock_account_service.login.return_value = oauth_setup["token_pair"]
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
|
|
@ -260,7 +260,7 @@ class TestOAuthCallback:
|
|||
account = MagicMock()
|
||||
account.status = account_status
|
||||
account.id = "123"
|
||||
mock_generate_account.return_value = account
|
||||
mock_generate_account.return_value = (account, False)
|
||||
|
||||
# Mock login for CLOSED status
|
||||
mock_token_pair = MagicMock()
|
||||
|
|
@ -296,7 +296,7 @@ class TestOAuthCallback:
|
|||
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = AccountStatus.PENDING
|
||||
mock_generate_account.return_value = mock_account
|
||||
mock_generate_account.return_value = (mock_account, False)
|
||||
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
|
|
@ -360,7 +360,7 @@ class TestOAuthCallback:
|
|||
closed_account.status = AccountStatus.CLOSED
|
||||
closed_account.id = "123"
|
||||
closed_account.name = "Closed Account"
|
||||
mock_generate_account.return_value = closed_account
|
||||
mock_generate_account.return_value = (closed_account, False)
|
||||
|
||||
# Mock successful login (current behavior)
|
||||
mock_token_pair = MagicMock()
|
||||
|
|
@ -458,8 +458,9 @@ class TestAccountGeneration:
|
|||
with pytest.raises(AccountRegisterError):
|
||||
_generate_account("github", user_info)
|
||||
else:
|
||||
result = _generate_account("github", user_info)
|
||||
result, oauth_new_user = _generate_account("github", user_info)
|
||||
assert result == mock_account
|
||||
assert oauth_new_user == should_create
|
||||
|
||||
if should_create:
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
|
|
@ -490,9 +491,10 @@ class TestAccountGeneration:
|
|||
mock_tenant_service.create_tenant.return_value = mock_new_tenant
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
result = _generate_account("github", user_info)
|
||||
result, oauth_new_user = _generate_account("github", user_info)
|
||||
|
||||
assert result == mock_account
|
||||
assert oauth_new_user is False
|
||||
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
|
||||
mock_tenant_service.create_tenant_member.assert_called_once_with(
|
||||
mock_new_tenant, mock_account, role="owner"
|
||||
|
|
|
|||
Loading…
Reference in New Issue