rm duplicated _get_invitation_with_case_fallback

This commit is contained in:
hjlarry 2025-12-22 14:12:08 +08:00
parent c8d7a4e10f
commit 61f962b728
7 changed files with 63 additions and 58 deletions

View File

@ -1,5 +1,3 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
@ -67,7 +65,7 @@ class ActivateCheckApi(Resource):
workspaceId = args.workspace_id
token = args.token
invitation = _get_invitation_with_case_fallback(workspaceId, args.email, token)
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
@ -103,7 +101,7 @@ class ActivateApi(Resource):
args = ActivatePayload.model_validate(console_ns.payload)
normalized_request_email = args.email.lower() if args.email else None
invitation = _get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
@ -122,13 +120,3 @@ class ActivateApi(Resource):
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()}
def _get_invitation_with_case_fallback(
workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation = RegisterService.get_invitation_if_token_valid(workspace_id, email, token)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return RegisterService.get_invitation_if_token_valid(workspace_id, normalized_email, token)

View File

@ -103,7 +103,7 @@ class LoginApi(Resource):
invite_token = args.invite_token
invitation: dict[str, Any] | None = None
if invite_token:
invitation = _get_invitation_with_case_fallback(None, request_email, invite_token)
invitation = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation is None:
invite_token = None
@ -322,17 +322,6 @@ class RefreshTokenApi(Resource):
return {"result": "fail", "message": str(e)}, 401
def _get_invitation_with_case_fallback(
workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation = RegisterService.get_invitation_if_token_valid(workspace_id, email, token)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return RegisterService.get_invitation_if_token_valid(workspace_id, normalized_email, token)
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
if account or email == email.lower():

View File

@ -1509,6 +1509,16 @@ class RegisterService:
invitation: dict = json.loads(data)
return invitation
@classmethod
def get_invitation_with_case_fallback(
cls, workspace_id: str | None, email: str | None, token: str
) -> dict[str, Any] | None:
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
def _generate_refresh_token(length: int = 64):
token = secrets.token_hex(length)

View File

@ -8,7 +8,7 @@ This module tests the account activation mechanism including:
- Initial login after activation
"""
from unittest.mock import MagicMock, call, patch
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
@ -40,7 +40,7 @@ class TestActivateCheckApi:
"tenant": tenant,
}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
"""
Test checking valid invitation token.
@ -66,7 +66,7 @@ class TestActivateCheckApi:
assert response["data"]["workspace_id"] == "workspace-123"
assert response["data"]["email"] == "invitee@example.com"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_invalid_invitation_token(self, mock_get_invitation, app):
"""
Test checking invalid invitation token.
@ -88,7 +88,7 @@ class TestActivateCheckApi:
# Assert
assert response["is_valid"] is False
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
"""
Test checking token without workspace ID.
@ -109,7 +109,7 @@ class TestActivateCheckApi:
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
"""
Test checking token without email parameter.
@ -130,7 +130,7 @@ class TestActivateCheckApi:
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
"""Ensure token validation uses lowercase emails."""
mock_get_invitation.return_value = mock_invitation
@ -142,10 +142,7 @@ class TestActivateCheckApi:
response = api.get()
assert response["is_valid"] is True
assert mock_get_invitation.call_args_list == [
call("workspace-123", "Invitee@Example.com", "valid_token"),
call("workspace-123", "invitee@example.com", "valid_token"),
]
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
class TestActivateApi:
@ -194,7 +191,7 @@ class TestActivateApi:
}
return token_pair
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
@ -249,7 +246,7 @@ class TestActivateApi:
mock_db.session.commit.assert_called_once()
mock_login.assert_called_once()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_activation_with_invalid_token(self, mock_get_invitation, app):
"""
Test account activation with invalid token.
@ -278,7 +275,7 @@ class TestActivateApi:
with pytest.raises(AlreadyActivateError):
api.post()
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
@ -331,7 +328,7 @@ class TestActivateApi:
("es-ES", "Europe/Madrid"),
],
)
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
@ -381,7 +378,7 @@ class TestActivateApi:
assert mock_account.interface_language == language
assert mock_account.timezone == timezone
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
@ -428,7 +425,7 @@ class TestActivateApi:
assert response["data"]["refresh_token"] == "refresh_token"
assert response["data"]["csrf_token"] == "csrf_token"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
@ -472,7 +469,7 @@ class TestActivateApi:
assert response["result"] == "success"
mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
@patch("controllers.console.auth.activate.db")
@patch("controllers.console.auth.activate.AccountService.login")
@ -507,8 +504,5 @@ class TestActivateApi:
response = api.post()
assert response["result"] == "success"
assert mock_get_invitation.call_args_list == [
call("workspace-123", "Invitee@Example.com", "valid_token"),
call("workspace-123", "invitee@example.com", "valid_token"),
]
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")

View File

@ -34,7 +34,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invalid_email_with_registration_allowed(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):
@ -67,7 +67,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_wrong_password_returns_error(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
):
@ -100,7 +100,7 @@ class TestAuthenticationSecurity:
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invalid_email_with_registration_disabled(
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
):

View File

@ -76,7 +76,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@ -128,7 +128,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.AccountService.login")
@ -182,7 +182,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
"""
Test login rejection when rate limit is exceeded.
@ -230,7 +230,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
def test_login_fails_with_invalid_credentials(
@ -269,7 +269,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
def test_login_fails_for_banned_account(
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
@ -298,7 +298,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
@patch("controllers.console.auth.login.FeatureService.get_system_features")
@ -343,7 +343,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
"""
Test login failure when invitation email doesn't match login email.
@ -374,7 +374,7 @@ class TestLoginApi:
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
@patch("controllers.console.auth.login.AccountService.authenticate")
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
@patch("controllers.console.auth.login.TenantService.get_join_tenants")

View File

@ -1554,6 +1554,30 @@ class TestRegisterService:
# Verify results
assert result is None
def test_get_invitation_with_case_fallback_returns_initial_match(self):
"""Fallback helper should return the initial invitation when present."""
invitation = {"workspace_id": "tenant-456"}
with patch(
"services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation
) as mock_get:
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
assert result == invitation
mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123")
def test_get_invitation_with_case_fallback_retries_with_lowercase(self):
"""Fallback helper should retry with lowercase email when needed."""
invitation = {"workspace_id": "tenant-456"}
with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get:
mock_get.side_effect = [None, invitation]
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
assert result == invitation
assert mock_get.call_args_list == [
(("tenant-456", "User@Test.com", "token-123"),),
(("tenant-456", "user@test.com", "token-123"),),
]
# ==================== Helper Method Tests ====================
def test_get_invitation_token_key(self):