From 2e874313eceb81c8dd3047aa3437a427f279ca29 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Sun, 21 Dec 2025 18:41:21 +0800 Subject: [PATCH] improve active email lower --- api/controllers/console/auth/activate.py | 21 ++++++++++++++----- .../console/auth/test_account_activation.py | 12 ++++++++--- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index c700f62d62..87df67f85c 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,3 +1,5 @@ +from typing import Any + from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator @@ -63,10 +65,9 @@ class ActivateCheckApi(Resource): args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore workspaceId = args.workspace_id - reg_email = args.email.lower() if args.email else None token = args.token - invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) + invitation = _get_invitation_with_case_fallback(workspaceId, args.email, token) if invitation: data = invitation.get("data", {}) tenant = invitation.get("tenant", None) @@ -101,12 +102,12 @@ class ActivateApi(Resource): def post(self): args = ActivatePayload.model_validate(console_ns.payload) - normalized_email = args.email.lower() if args.email else None - invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, normalized_email, args.token) + normalized_request_email = args.email.lower() if args.email else None + invitation = _get_invitation_with_case_fallback(args.workspace_id, args.email, args.token) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args.workspace_id, normalized_email, args.token) + RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token) account = invitation["account"] account.name = args.name @@ -121,3 +122,13 @@ class ActivateApi(Resource): token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) return {"result": "success", "data": token_pair.model_dump()} + + +def _get_invitation_with_case_fallback( + workspace_id: str | None, email: str | None, token: str +) -> dict[str, Any] | None: + invitation = RegisterService.get_invitation_if_token_valid(workspace_id, email, token) + if invitation or not email or email == email.lower(): + return invitation + normalized_email = email.lower() + return RegisterService.get_invitation_if_token_valid(workspace_id, normalized_email, token) diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index a9801ce0a9..e1f618cc60 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -8,7 +8,7 @@ This module tests the account activation mechanism including: - Initial login after activation """ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import pytest from flask import Flask @@ -142,7 +142,10 @@ class TestActivateCheckApi: response = api.get() assert response["is_valid"] is True - mock_get_invitation.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") + assert mock_get_invitation.call_args_list == [ + call("workspace-123", "Invitee@Example.com", "valid_token"), + call("workspace-123", "invitee@example.com", "valid_token"), + ] class TestActivateApi: @@ -504,5 +507,8 @@ class TestActivateApi: response = api.post() assert response["result"] == "success" - mock_get_invitation.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") + assert mock_get_invitation.call_args_list == [ + call("workspace-123", "Invitee@Example.com", "valid_token"), + call("workspace-123", "invitee@example.com", "valid_token"), + ] mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")