From 61f962b7285c0493f27f3e6c38524d96217b3405 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Mon, 22 Dec 2025 14:12:08 +0800 Subject: [PATCH] rm duplicated _get_invitation_with_case_fallback --- api/controllers/console/auth/activate.py | 16 ++------- api/controllers/console/auth/login.py | 13 +------ api/services/account_service.py | 10 ++++++ .../console/auth/test_account_activation.py | 36 ++++++++----------- .../auth/test_authentication_security.py | 6 ++-- .../console/auth/test_login_logout.py | 16 ++++----- .../services/test_account_service.py | 24 +++++++++++++ 7 files changed, 63 insertions(+), 58 deletions(-) diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 87df67f85c..6bdec9c163 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,5 +1,3 @@ -from typing import Any - from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator @@ -67,7 +65,7 @@ class ActivateCheckApi(Resource): workspaceId = args.workspace_id token = args.token - invitation = _get_invitation_with_case_fallback(workspaceId, args.email, token) + invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token) if invitation: data = invitation.get("data", {}) tenant = invitation.get("tenant", None) @@ -103,7 +101,7 @@ class ActivateApi(Resource): args = ActivatePayload.model_validate(console_ns.payload) normalized_request_email = args.email.lower() if args.email else None - invitation = _get_invitation_with_case_fallback(args.workspace_id, args.email, args.token) + invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token) if invitation is None: raise AlreadyActivateError() @@ -122,13 +120,3 @@ class ActivateApi(Resource): token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) return {"result": "success", "data": token_pair.model_dump()} - - -def _get_invitation_with_case_fallback( - workspace_id: str | None, email: str | None, token: str -) -> dict[str, Any] | None: - invitation = RegisterService.get_invitation_if_token_valid(workspace_id, email, token) - if invitation or not email or email == email.lower(): - return invitation - normalized_email = email.lower() - return RegisterService.get_invitation_if_token_valid(workspace_id, normalized_email, token) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index e8541b91bb..c1d1a9311d 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -103,7 +103,7 @@ class LoginApi(Resource): invite_token = args.invite_token invitation: dict[str, Any] | None = None if invite_token: - invitation = _get_invitation_with_case_fallback(None, request_email, invite_token) + invitation = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token) if invitation is None: invite_token = None @@ -322,17 +322,6 @@ class RefreshTokenApi(Resource): return {"result": "fail", "message": str(e)}, 401 -def _get_invitation_with_case_fallback( - workspace_id: str | None, email: str | None, token: str -) -> dict[str, Any] | None: - invitation = RegisterService.get_invitation_if_token_valid(workspace_id, email, token) - if invitation or not email or email == email.lower(): - return invitation - - normalized_email = email.lower() - return RegisterService.get_invitation_if_token_valid(workspace_id, normalized_email, token) - - def _get_account_with_case_fallback(email: str): account = AccountService.get_user_through_email(email) if account or email == email.lower(): diff --git a/api/services/account_service.py b/api/services/account_service.py index 65bcc6de3b..1fea297eb2 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1509,6 +1509,16 @@ class RegisterService: invitation: dict = json.loads(data) return invitation + @classmethod + def get_invitation_with_case_fallback( + cls, workspace_id: str | None, email: str | None, token: str + ) -> dict[str, Any] | None: + invitation = cls.get_invitation_if_token_valid(workspace_id, email, token) + if invitation or not email or email == email.lower(): + return invitation + normalized_email = email.lower() + return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token) + def _generate_refresh_token(length: int = 64): token = secrets.token_hex(length) diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index e1f618cc60..b8574a5127 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -8,7 +8,7 @@ This module tests the account activation mechanism including: - Initial login after activation """ -from unittest.mock import MagicMock, call, patch +from unittest.mock import MagicMock, patch import pytest from flask import Flask @@ -40,7 +40,7 @@ class TestActivateCheckApi: "tenant": tenant, } - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation): """ Test checking valid invitation token. @@ -66,7 +66,7 @@ class TestActivateCheckApi: assert response["data"]["workspace_id"] == "workspace-123" assert response["data"]["email"] == "invitee@example.com" - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_invalid_invitation_token(self, mock_get_invitation, app): """ Test checking invalid invitation token. @@ -88,7 +88,7 @@ class TestActivateCheckApi: # Assert assert response["is_valid"] is False - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation): """ Test checking token without workspace ID. @@ -109,7 +109,7 @@ class TestActivateCheckApi: assert response["is_valid"] is True mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token") - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation): """ Test checking token without email parameter. @@ -130,7 +130,7 @@ class TestActivateCheckApi: assert response["is_valid"] is True mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token") - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation): """Ensure token validation uses lowercase emails.""" mock_get_invitation.return_value = mock_invitation @@ -142,10 +142,7 @@ class TestActivateCheckApi: response = api.get() assert response["is_valid"] is True - assert mock_get_invitation.call_args_list == [ - call("workspace-123", "Invitee@Example.com", "valid_token"), - call("workspace-123", "invitee@example.com", "valid_token"), - ] + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") class TestActivateApi: @@ -194,7 +191,7 @@ class TestActivateApi: } return token_pair - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.AccountService.login") @@ -249,7 +246,7 @@ class TestActivateApi: mock_db.session.commit.assert_called_once() mock_login.assert_called_once() - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_activation_with_invalid_token(self, mock_get_invitation, app): """ Test account activation with invalid token. @@ -278,7 +275,7 @@ class TestActivateApi: with pytest.raises(AlreadyActivateError): api.post() - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.AccountService.login") @@ -331,7 +328,7 @@ class TestActivateApi: ("es-ES", "Europe/Madrid"), ], ) - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.AccountService.login") @@ -381,7 +378,7 @@ class TestActivateApi: assert mock_account.interface_language == language assert mock_account.timezone == timezone - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.AccountService.login") @@ -428,7 +425,7 @@ class TestActivateApi: assert response["data"]["refresh_token"] == "refresh_token" assert response["data"]["csrf_token"] == "csrf_token" - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.AccountService.login") @@ -472,7 +469,7 @@ class TestActivateApi: assert response["result"] == "success" mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token") - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") @patch("controllers.console.auth.activate.AccountService.login") @@ -507,8 +504,5 @@ class TestActivateApi: response = api.post() assert response["result"] == "success" - assert mock_get_invitation.call_args_list == [ - call("workspace-123", "Invitee@Example.com", "valid_token"), - call("workspace-123", "invitee@example.com", "valid_token"), - ] + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index eb21920117..cb4fe40944 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -34,7 +34,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invalid_email_with_registration_allowed( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): @@ -67,7 +67,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_wrong_password_returns_error( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db ): @@ -100,7 +100,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invalid_email_with_registration_disabled( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index fda0fecc70..560971206f 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -76,7 +76,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.AccountService.login") @@ -128,7 +128,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.AccountService.login") @@ -182,7 +182,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): """ Test login rejection when rate limit is exceeded. @@ -230,7 +230,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") def test_login_fails_with_invalid_credentials( @@ -269,7 +269,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") def test_login_fails_for_banned_account( self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app @@ -298,7 +298,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.FeatureService.get_system_features") @@ -343,7 +343,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): """ Test login failure when invitation email doesn't match login email. @@ -374,7 +374,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.TenantService.get_join_tenants") diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index e357eda805..011fc02669 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1554,6 +1554,30 @@ class TestRegisterService: # Verify results assert result is None + def test_get_invitation_with_case_fallback_returns_initial_match(self): + """Fallback helper should return the initial invitation when present.""" + invitation = {"workspace_id": "tenant-456"} + with patch( + "services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation + ) as mock_get: + result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + + assert result == invitation + mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123") + + def test_get_invitation_with_case_fallback_retries_with_lowercase(self): + """Fallback helper should retry with lowercase email when needed.""" + invitation = {"workspace_id": "tenant-456"} + with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get: + mock_get.side_effect = [None, invitation] + result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + + assert result == invitation + assert mock_get.call_args_list == [ + (("tenant-456", "User@Test.com", "token-123"),), + (("tenant-456", "user@test.com", "token-123"),), + ] + # ==================== Helper Method Tests ==================== def test_get_invitation_token_key(self):