diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 74319d1336..9d92c8bd79 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -56,6 +56,12 @@ from models.enums import CreatorUserRole from models.model import UploadFile from services.account_service import AccountService from services.billing_service import BillingService +from services.entities.auth_entities import ( + ChangeEmailNewEmailToken, + ChangeEmailNewEmailVerifiedToken, + ChangeEmailOldEmailToken, + ChangeEmailOldEmailVerifiedToken, +) from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -620,8 +626,8 @@ class ChangeEmailSendEmailApi(Resource): language = "zh-Hans" else: language = "en-US" - account = None - user_email = None + account = current_user + user_email = current_user.email email_for_sending = args.email.lower() # Default to the initial phase; any legacy/unexpected client input is # coerced back to `old_email` so we never trust the caller to declare @@ -636,24 +642,18 @@ class ChangeEmailSendEmailApi(Resource): if reset_data is None: raise InvalidTokenError() - # The token used to request a new-email code must come from the - # old-email verification step. This prevents the bypass described - # in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here. - token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) - if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED: + if not isinstance(reset_data, ChangeEmailOldEmailVerifiedToken): raise InvalidTokenError() - user_email = reset_data.get("email", "") + if not reset_data.is_bound_to_account(current_user.id): + raise InvalidTokenError() + user_email = reset_data.email if user_email.lower() != current_user.email.lower(): raise InvalidEmailError() - - user_email = current_user.email else: - account = AccountService.get_account_by_email_with_case_fallback(args.email) - if account is None: - raise AccountNotFound() - email_for_sending = account.email - user_email = account.email + if email_for_sending != current_user.email.lower(): + raise InvalidEmailError() + email_for_sending = current_user.email token = AccountService.send_change_email_email( account=account, @@ -674,6 +674,7 @@ class ChangeEmailCheckApi(Resource): @login_required @account_initialization_required def post(self): + current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} args = ChangeEmailValidityPayload.model_validate(payload) @@ -686,42 +687,26 @@ class ChangeEmailCheckApi(Resource): token_data = AccountService.get_change_email_data(args.token) if token_data is None: raise InvalidTokenError() + if not token_data.is_bound_to_account(current_user.id): + raise InvalidTokenError() - token_email = token_data.get("email") - normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + normalized_token_email = token_data.email.lower() if user_email != normalized_token_email: raise InvalidEmailError() - if args.code != token_data.get("code"): + if args.code != token_data.code: AccountService.add_change_email_error_rate_limit(user_email) raise EmailCodeError() - # Only advance tokens that were minted by the matching send-code step; - # refuse tokens that have already progressed or lack a phase marker so - # the chain `old_email -> old_email_verified -> new_email -> new_email_verified` - # is strictly enforced. - phase_transitions = { - AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, - AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, - } - token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) - if not isinstance(token_phase, str): - raise InvalidTokenError() - refreshed_phase = phase_transitions.get(token_phase) - if refreshed_phase is None: + if isinstance(token_data, ChangeEmailOldEmailToken | ChangeEmailNewEmailToken): + refreshed_token_data = token_data.promote() + else: raise InvalidTokenError() # Verified, revoke the first token AccountService.revoke_change_email_token(args.token) - # Refresh token data by generating a new token that carries the - # upgraded phase so later steps can check it. - _, new_token = AccountService.generate_change_email_token( - user_email, - code=args.code, - old_email=token_data.get("old_email"), - additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase}, - ) + new_token = AccountService.generate_change_email_token(refreshed_token_data, current_user) AccountService.reset_change_email_error_rate_limit(user_email) return {"is_valid": True, "email": normalized_token_email, "token": new_token} @@ -746,27 +731,22 @@ class ChangeEmailResetApi(Resource): if not AccountService.check_email_unique(normalized_new_email): raise EmailAlreadyInUseError() + current_user, _ = current_account_with_tenant() reset_data = AccountService.get_change_email_data(args.token) if not reset_data: raise InvalidTokenError() + if not reset_data.is_bound_to_account(current_user.id): + raise InvalidTokenError() - # Only tokens that completed both verification phases may be used to - # change the email. This closes GHSA-4q3w-q5mc-45rq where a token from - # the initial send-code step could be replayed directly here. - token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) - if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED: + if not isinstance(reset_data, ChangeEmailNewEmailVerifiedToken): raise InvalidTokenError() # Bind the new email to the token that was mailed and verified, so a # verified token cannot be reused with a different `new_email` value. - token_email = reset_data.get("email") - normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email - if normalized_token_email != normalized_new_email: + if reset_data.email.lower() != normalized_new_email: raise InvalidTokenError() - old_email = reset_data.get("old_email", "") - current_user, _ = current_account_with_tenant() - if current_user.email.lower() != old_email.lower(): + if current_user.email.lower() != reset_data.old_email.lower(): raise AccountNotFound() # Revoke only after all checks pass so failed attempts don't burn a diff --git a/api/libs/helper.py b/api/libs/helper.py index 04900f385c..57e808a408 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,7 +16,7 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel, ConfigDict, TypeAdapter, with_config from pydantic.functional_validators import AfterValidator from typing_extensions import TypedDict @@ -33,13 +33,29 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@with_config(ConfigDict(extra="allow")) class _TokenData(TypedDict, total=False): + """Shared baseline token payload. + + `extra='allow'` keeps TokenManager from silently stripping business- + specific metadata keys while still validating the common auth fields. + Business flows that need stronger guarantees should validate again at + their own boundary with a dedicated Pydantic model. + + For the change-email flow specifically, `email_change_phase` is the + discriminator used by `services.entities.auth_entities.ChangeEmailTokenData`. + It is declared here so the shared token adapter can still provide baseline + validation for the state-machine key without taking over the full business + model. + """ + account_id: str | None email: str token_type: str code: str old_email: str phase: str + email_change_phase: str _token_data_adapter: TypeAdapter[_TokenData] = TypeAdapter(_TokenData) @@ -466,7 +482,7 @@ class TokenManager: raise ValueError("Account or email must be provided") account_id = account.id if account else None - account_email = account.email if account else email + account_email = email if email is not None else account.email if account else None if account_id: old_token = cls._get_current_token_for_account(account_id, token_type) @@ -508,8 +524,7 @@ class TokenManager: if token_data_json is None: logger.warning("%s token %s not found with key %s", token_type, token, key) return None - token_data = dict(_token_data_adapter.validate_json(token_data_json)) - return token_data + return dict(_token_data_adapter.validate_json(token_data_json)) @classmethod def _get_current_token_for_account(cls, account_id: str, token_type: str) -> str | None: diff --git a/api/services/account_service.py b/api/services/account_service.py index 6533526b60..e020831180 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -7,7 +7,7 @@ from datetime import UTC, datetime, timedelta from hashlib import sha256 from typing import Any, TypedDict, cast -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel, TypeAdapter, ValidationError from sqlalchemy import delete, func, select, update from core.db.session_factory import session_factory @@ -46,6 +46,12 @@ from models.account import ( ) from models.model import DifySetup from services.billing_service import BillingService +from services.entities.auth_entities import ( + ChangeEmailNewEmailToken, + ChangeEmailOldEmailToken, + ChangeEmailPhase, + ChangeEmailTokenData, +) from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, @@ -84,6 +90,8 @@ from tasks.mail_reset_password_task import ( logger = logging.getLogger(__name__) +_change_email_token_adapter: TypeAdapter[ChangeEmailTokenData] = TypeAdapter(ChangeEmailTokenData) + class InvitationDetailDict(TypedDict): account: Account @@ -113,13 +121,10 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: - # Phase-bound token metadata for the change-email flow. Tokens carry the - # current phase so that downstream endpoints can enforce proper progression - CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase" - CHANGE_EMAIL_PHASE_OLD = "old_email" - CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified" - CHANGE_EMAIL_PHASE_NEW = "new_email" - CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified" + CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD_EMAIL + CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_EMAIL_VERIFIED + CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW_EMAIL + CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_EMAIL_VERIFIED reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1) @@ -583,31 +588,42 @@ class AccountService: @classmethod def send_change_email_email( cls, - account: Account | None = None, + account: Account, email: str | None = None, old_email: str | None = None, language: str = "en-US", phase: str | None = None, ): - account_email = account.email if account else email - if account_email is None: - raise ValueError("Email must be provided.") + account_email = email if email is not None else account.email if not phase: raise ValueError("phase must be provided.") if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW): raise ValueError("phase must be one of old_email or new_email.") + if old_email is None: + raise ValueError("old_email must be provided.") if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60)) - code, token = cls.generate_change_email_token( - account_email, - account, - old_email=old_email, - additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase}, - ) + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + token_data: ChangeEmailTokenData + if phase == cls.CHANGE_EMAIL_PHASE_OLD: + token_data = ChangeEmailOldEmailToken( + account_id=account.id, + email=account_email, + old_email=old_email, + code=code, + ) + else: + token_data = ChangeEmailNewEmailToken( + account_id=account.id, + email=account_email, + old_email=old_email, + code=code, + ) + token = cls.generate_change_email_token(token_data, account) send_change_mail_task.delay( language=language, @@ -735,20 +751,16 @@ class AccountService: @classmethod def generate_change_email_token( cls, - email: str, - account: Account | None = None, - code: str | None = None, - old_email: str | None = None, - additional_data: dict[str, Any] = {}, - ): - if not code: - code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) - additional_data["code"] = code - additional_data["old_email"] = old_email + token_data: ChangeEmailTokenData, + account: Account, + ) -> str: token = TokenManager.generate_token( - account=account, email=email, token_type="change_email", additional_data=additional_data + account=account, + email=token_data.email, + token_type="change_email", + additional_data=token_data.to_token_manager_payload(), ) - return code, token + return token @classmethod def generate_owner_transfer_token( @@ -791,8 +803,15 @@ class AccountService: return TokenManager.get_token_data(token, "email_register") @classmethod - def get_change_email_data(cls, token: str) -> dict[str, Any] | None: - return TokenManager.get_token_data(token, "change_email") + def get_change_email_data(cls, token: str) -> ChangeEmailTokenData | None: + token_data = TokenManager.get_token_data(token, "change_email") + if token_data is None: + return None + try: + return _change_email_token_adapter.validate_python(token_data) + except ValidationError: + logger.warning("change_email token %s has invalid payload", token, exc_info=True) + return None @classmethod def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None: diff --git a/api/services/entities/auth_entities.py b/api/services/entities/auth_entities.py index e3fb249692..79c0b63758 100644 --- a/api/services/entities/auth_entities.py +++ b/api/services/entities/auth_entities.py @@ -1,6 +1,7 @@ from enum import StrEnum, auto +from typing import Annotated, Literal -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from libs.helper import EmailStr from libs.password import valid_password @@ -20,6 +21,24 @@ class LoginFailureReason(StrEnum): LOGIN_RATE_LIMITED = auto() +class ChangeEmailPhase(StrEnum): + """Change-email token state machine. + + Allowed transitions: + + `OLD_EMAIL -> OLD_EMAIL_VERIFIED -> NEW_EMAIL -> NEW_EMAIL_VERIFIED` + + The flow starts by sending a code to the current email address. Only a + token in `OLD_EMAIL_VERIFIED` may request the new-email code, and only a + token in `NEW_EMAIL_VERIFIED` may perform the final email reset. + """ + + OLD_EMAIL = "old_email" + OLD_EMAIL_VERIFIED = "old_email_verified" + NEW_EMAIL = "new_email" + NEW_EMAIL_VERIFIED = "new_email_verified" + + class LoginPayloadBase(BaseModel): email: EmailStr password: str @@ -45,3 +64,122 @@ class ForgotPasswordResetPayload(BaseModel): @classmethod def validate_password(cls, value: str) -> str: return valid_password(value) + + +class ChangeEmailTokenBase(BaseModel): + """Stored change-email token payload. + + The discriminator lives in `email_change_phase`; callers use the concrete + model type to decide which transitions are legal. + + The full progression is: + + `old_email -> old_email_verified -> new_email -> new_email_verified` + + Every state is bound to the initiating `account_id` so the change-email + flow cannot be replayed across accounts. + """ + + token_type: Literal["change_email"] = "change_email" + account_id: str = Field(min_length=1) + email: EmailStr + old_email: EmailStr + code: str = Field(min_length=1) + + model_config = ConfigDict(extra="forbid") + + def to_token_manager_payload(self) -> dict[str, str]: + return self.model_dump(exclude={"token_type", "account_id", "email"}) + + def is_bound_to_account(self, account_id: str) -> bool: + return self.account_id == account_id + + +class _ChangeEmailOldAddressMixin(ChangeEmailTokenBase): + """States whose `email` must still be the account's current address.""" + + @model_validator(mode="after") + def validate_old_address_binding(self) -> "_ChangeEmailOldAddressMixin": + if self.email.lower() != self.old_email.lower(): + raise ValueError("old-email token payload must bind email to old_email") + return self + + +class ChangeEmailOldEmailToken(_ChangeEmailOldAddressMixin): + """Phase-1 token minted when sending a code to the old email address. + + This token proves only that the flow started for the current account. It + must not unlock the new-email send step or the final reset step until the + old-email verification code has been checked. + """ + + email_change_phase: Literal[ChangeEmailPhase.OLD_EMAIL] = ChangeEmailPhase.OLD_EMAIL + + def promote(self) -> "ChangeEmailOldEmailVerifiedToken": + """Advance to the state that is allowed to request the new-email code.""" + return ChangeEmailOldEmailVerifiedToken( + **self.model_dump(exclude={"email_change_phase"}), + email_change_phase=ChangeEmailPhase.OLD_EMAIL_VERIFIED, + ) + + +class ChangeEmailOldEmailVerifiedToken(_ChangeEmailOldAddressMixin): + """Token returned after the old email verification code succeeds. + + The token used to request a new-email code must come from this state. This + blocks the GHSA-4q3w-q5mc-45rq bypass where a phase-1 token was replayed to + skip the old-email verification step. + """ + + email_change_phase: Literal[ChangeEmailPhase.OLD_EMAIL_VERIFIED] = ChangeEmailPhase.OLD_EMAIL_VERIFIED + + +class ChangeEmailNewEmailToken(ChangeEmailTokenBase): + """Token minted when sending a code to the target new email address. + + At this point the account binding is already fixed, but the new address has + not been verified yet, so the token may only be promoted by a successful + new-email verification code check. + """ + + email_change_phase: Literal[ChangeEmailPhase.NEW_EMAIL] = ChangeEmailPhase.NEW_EMAIL + + def promote(self) -> "ChangeEmailNewEmailVerifiedToken": + """Advance to the only state that may perform the final email reset.""" + return ChangeEmailNewEmailVerifiedToken( + **self.model_dump(exclude={"email_change_phase"}), + email_change_phase=ChangeEmailPhase.NEW_EMAIL_VERIFIED, + ) + + +class ChangeEmailNewEmailVerifiedToken(ChangeEmailTokenBase): + """Final verified token for the change-email flow. + + Only this state may change the account email, and the reset endpoint must + additionally require that the request's `new_email` matches this token's + `email` so a verified token for address A cannot be replayed for address B. + """ + + email_change_phase: Literal[ChangeEmailPhase.NEW_EMAIL_VERIFIED] = ChangeEmailPhase.NEW_EMAIL_VERIFIED + + +# Tokens that can still advance by verifying a code. +ChangeEmailPendingTokenData = Annotated[ + ChangeEmailOldEmailToken | ChangeEmailNewEmailToken, + Field(discriminator="email_change_phase"), +] + +# Tokens that already completed a verification step. +ChangeEmailVerifiedTokenData = Annotated[ + ChangeEmailOldEmailVerifiedToken | ChangeEmailNewEmailVerifiedToken, + Field(discriminator="email_change_phase"), +] + +# Complete change-email token state machine. +ChangeEmailTokenData = Annotated[ + ChangeEmailOldEmailToken + | ChangeEmailOldEmailVerifiedToken + | ChangeEmailNewEmailToken + | ChangeEmailNewEmailVerifiedToken, + Field(discriminator="email_change_phase"), +] diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 4b4f968c8f..95d7493b71 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -13,6 +13,12 @@ from controllers.console.workspace.account import ( ) from models import Account, AccountStatus from services.account_service import AccountService +from services.entities.auth_entities import ( + ChangeEmailNewEmailToken, + ChangeEmailNewEmailVerifiedToken, + ChangeEmailOldEmailToken, + ChangeEmailOldEmailVerifiedToken, +) @pytest.fixture @@ -39,7 +45,66 @@ def _set_logged_in_user(account: Account): g._current_tenant = account.current_tenant +def _build_change_email_token( + phase: str, + *, + account_id: str = "acc", + email: str, + old_email: str, + code: str = "1234", +): + token_kwargs = { + "account_id": account_id, + "email": email, + "old_email": old_email, + "code": code, + } + if phase == AccountService.CHANGE_EMAIL_PHASE_OLD: + return ChangeEmailOldEmailToken(**token_kwargs) + if phase == AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED: + return ChangeEmailOldEmailVerifiedToken(**token_kwargs) + if phase == AccountService.CHANGE_EMAIL_PHASE_NEW: + return ChangeEmailNewEmailToken(**token_kwargs) + if phase == AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED: + return ChangeEmailNewEmailVerifiedToken(**token_kwargs) + raise AssertionError(f"Unsupported phase for test helper: {phase}") + + class TestChangeEmailSend: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_old_email_phase_when_request_email_does_not_match_current_user( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_current_account, + mock_db, + app: Flask, + ): + from controllers.console.auth.error import InvalidEmailError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("current@example.com", "acc1"), None) + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "other@example.com", "language": "en-US", "phase": "old_email"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidEmailError): + ChangeEmailSendEmailApi().post() + + mock_send_email.assert_not_called() + @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @@ -63,10 +128,12 @@ class TestChangeEmailSend: mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) - mock_get_change_data.return_value = { - "email": "current@example.com", - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, - } + mock_get_change_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + account_id="acc1", + email="current@example.com", + old_email="current@example.com", + ) mock_send_email.return_value = "token-abc" with app.test_request_context( @@ -79,7 +146,7 @@ class TestChangeEmailSend: assert response == {"result": "success", "data": "token-abc"} mock_send_email.assert_called_once_with( - account=None, + account=mock_account, email="new@example.com", old_email="current@example.com", language="en-US", @@ -115,10 +182,12 @@ class TestChangeEmailSend: mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) - mock_get_change_data.return_value = { - "email": "current@example.com", - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, - } + mock_get_change_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_OLD, + account_id="acc1", + email="current@example.com", + old_email="current@example.com", + ) with app.test_request_context( "/account/change-email", @@ -131,6 +200,49 @@ class TestChangeEmailSend: mock_send_email.assert_not_called() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_new_email_phase_when_token_account_id_does_not_match_current_user( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_change_data, + mock_current_account, + mock_db, + app: Flask, + ): + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + account_id="other-account", + email="current@example.com", + old_email="current@example.com", + ) + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "new@example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailSendEmailApi().post() + + mock_send_email.assert_not_called() + class TestChangeEmailValidity: @patch("controllers.console.wraps.db") @@ -161,13 +273,13 @@ class TestChangeEmailValidity: mock_account = _build_account("user@example.com", "acc2") mock_current_account.return_value = (mock_account, None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = { - "email": "user@example.com", - "code": "1234", - "old_email": "old@example.com", - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, - } - mock_generate_token.return_value = (None, "new-token") + mock_get_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_OLD, + account_id="acc2", + email="user@example.com", + old_email="user@example.com", + ) + mock_generate_token.return_value = "new-token" with app.test_request_context( "/account/change-email/validity", @@ -182,12 +294,13 @@ class TestChangeEmailValidity: mock_add_rate.assert_not_called() mock_revoke_token.assert_called_once_with("token-123") mock_generate_token.assert_called_once_with( - "user@example.com", - code="1234", - old_email="old@example.com", - additional_data={ - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, - }, + _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + account_id="acc2", + email="user@example.com", + old_email="user@example.com", + ), + mock_account, ) mock_reset_rate.assert_called_once_with("user@example.com") mock_csrf.assert_called_once() @@ -219,13 +332,13 @@ class TestChangeEmailValidity: mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = { - "email": "new@example.com", - "code": "1234", - "old_email": "old@example.com", - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, - } - mock_generate_token.return_value = (None, "new-verified-token") + mock_get_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_NEW, + account_id="acc", + email="new@example.com", + old_email="old@example.com", + ) + mock_generate_token.return_value = "new-verified-token" with app.test_request_context( "/account/change-email/validity", @@ -237,12 +350,13 @@ class TestChangeEmailValidity: assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"} mock_generate_token.assert_called_once_with( - "new@example.com", - code="1234", - old_email="old@example.com", - additional_data={ - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, - }, + _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + account_id="acc", + email="new@example.com", + old_email="old@example.com", + ), + mock_current_account.return_value[0], ) @patch("controllers.console.wraps.db") @@ -255,7 +369,7 @@ class TestChangeEmailValidity: @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") @patch("libs.login.check_csrf_token", return_value=None) @patch("controllers.console.wraps.FeatureService.get_system_features") - def test_should_reject_validity_when_token_phase_is_unknown( + def test_should_reject_validity_when_token_is_already_verified( self, mock_features, mock_csrf, @@ -269,23 +383,22 @@ class TestChangeEmailValidity: mock_db, app: Flask, ): - """A token whose phase marker is a string but not a known transition must be rejected.""" from controllers.console.auth.error import InvalidTokenError mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = { - "email": "user@example.com", - "code": "1234", - "old_email": "old@example.com", - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else", - } + mock_get_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + account_id="acc", + email="old@example.com", + old_email="old@example.com", + ) with app.test_request_context( "/account/change-email/validity", method="POST", - json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + json={"email": "old@example.com", "code": "1234", "token": "token-123"}, ): _set_logged_in_user(_build_account("tester@example.com", "tester")) with pytest.raises(InvalidTokenError): @@ -304,7 +417,7 @@ class TestChangeEmailValidity: @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") @patch("libs.login.check_csrf_token", return_value=None) @patch("controllers.console.wraps.FeatureService.get_system_features") - def test_should_reject_validity_when_token_has_no_phase( + def test_should_reject_validity_when_token_account_id_does_not_match_current_user( self, mock_features, mock_csrf, @@ -318,22 +431,22 @@ class TestChangeEmailValidity: mock_db, app: Flask, ): - """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" from controllers.console.auth.error import InvalidTokenError mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = { - "email": "user@example.com", - "code": "1234", - "old_email": "old@example.com", - } + mock_get_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_NEW, + account_id="other-account", + email="new@example.com", + old_email="old@example.com", + ) with app.test_request_context( "/account/change-email/validity", method="POST", - json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + json={"email": "new@example.com", "code": "1234", "token": "token-123"}, ): _set_logged_in_user(_build_account("tester@example.com", "tester")) with pytest.raises(InvalidTokenError): @@ -373,11 +486,12 @@ class TestChangeEmailReset: mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True - mock_get_data.return_value = { - "email": "new@example.com", - "old_email": "OLD@example.com", - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, - } + mock_get_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + account_id="acc3", + email="new@example.com", + old_email="OLD@example.com", + ) mock_account_after_update = _build_account("new@example.com", "acc3-updated") mock_update_account.return_value = mock_account_after_update @@ -428,13 +542,12 @@ class TestChangeEmailReset: mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True - # Simulate a token straight out of step #1 (phase=old_email) — exactly - # the replay used in the advisory PoC. - mock_get_data.return_value = { - "email": "old@example.com", - "old_email": "old@example.com", - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, - } + mock_get_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_OLD, + account_id="acc3", + email="old@example.com", + old_email="old@example.com", + ) with app.test_request_context( "/account/change-email/reset", @@ -481,11 +594,12 @@ class TestChangeEmailReset: mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True - mock_get_data.return_value = { - "email": "verified@example.com", - "old_email": "old@example.com", - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, - } + mock_get_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + account_id="acc3", + email="verified@example.com", + old_email="old@example.com", + ) with app.test_request_context( "/account/change-email/reset", @@ -500,6 +614,57 @@ class TestChangeEmailReset: mock_update_account.assert_not_called() mock_send_notify.assert_not_called() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_account_id_does_not_match_current_user( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app: Flask, + ): + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + mock_get_data.return_value = _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + account_id="other-account", + email="new@example.com", + old_email="old@example.com", + ) + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "new@example.com", "token": "token-verified"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + class TestAccountServiceSendChangeEmailEmail: """Service-level coverage for the phase-bound changes in `send_change_email_email`.""" @@ -507,7 +672,8 @@ class TestAccountServiceSendChangeEmailEmail: def test_should_raise_value_error_for_invalid_phase(self): with pytest.raises(ValueError, match="phase must be one of"): AccountService.send_change_email_email( - email="user@example.com", + account=_build_account("old@example.com", "acc"), + email="new@example.com", old_email="user@example.com", phase="old_email_verified", ) @@ -515,33 +681,77 @@ class TestAccountServiceSendChangeEmailEmail: @patch("services.account_service.send_change_mail_task") @patch("services.account_service.AccountService.change_email_rate_limiter") @patch("services.account_service.AccountService.generate_change_email_token") - def test_should_stamp_phase_into_generated_token( + def test_should_bind_account_id_and_target_email_into_generated_token( self, mock_generate_token, mock_rate_limiter, mock_mail_task, ): mock_rate_limiter.is_rate_limited.return_value = False - mock_generate_token.return_value = ("123456", "the-token") + mock_generate_token.return_value = "the-token" + account = _build_account("old@example.com", "acc-123") returned = AccountService.send_change_email_email( - email="user@example.com", - old_email="user@example.com", + account=account, + email="new@example.com", + old_email="old@example.com", language="en-US", phase=AccountService.CHANGE_EMAIL_PHASE_NEW, ) assert returned == "the-token" mock_generate_token.assert_called_once_with( - "user@example.com", - None, - old_email="user@example.com", - additional_data={ - AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, - }, + _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_NEW, + account_id="acc-123", + email="new@example.com", + old_email="old@example.com", + code=mock_mail_task.delay.call_args.kwargs["code"], + ), + account, ) - mock_mail_task.delay.assert_called_once() - mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com") + mock_mail_task.delay.assert_called_once_with( + language="en-US", + to="new@example.com", + code=mock_mail_task.delay.call_args.kwargs["code"], + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, + ) + mock_rate_limiter.increment_rate_limit.assert_called_once_with("new@example.com") + + +class TestAccountServiceGetChangeEmailData: + @patch("services.account_service.TokenManager.get_token_data") + def test_should_parse_change_email_token_into_discriminated_union_model(self, mock_get_token_data): + mock_get_token_data.return_value = { + "token_type": "change_email", + "account_id": "acc-1", + "email": "new@example.com", + "old_email": "old@example.com", + "code": "654321", + "email_change_phase": AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + + token_data = AccountService.get_change_email_data("token-123") + + assert token_data == _build_change_email_token( + AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + account_id="acc-1", + email="new@example.com", + old_email="old@example.com", + code="654321", + ) + + @patch("services.account_service.TokenManager.get_token_data") + def test_should_reject_change_email_token_without_account_id(self, mock_get_token_data): + mock_get_token_data.return_value = { + "token_type": "change_email", + "email": "new@example.com", + "old_email": "old@example.com", + "code": "654321", + "email_change_phase": AccountService.CHANGE_EMAIL_PHASE_NEW, + } + + assert AccountService.get_change_email_data("token-123") is None class TestAccountDeletionFeedback: diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py index df0d2bda49..aa58db81da 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -388,6 +388,10 @@ class TestChangeEmailApis: with ( app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.account.current_account_with_tenant", + return_value=(MagicMock(id="acc-1"), "t1"), + ), patch.object( type(console_ns), "payload", @@ -400,7 +404,11 @@ class TestChangeEmailApis: ), patch( "controllers.console.workspace.account.AccountService.get_change_email_data", - return_value={"email": "a@test.com", "code": "y"}, + return_value=MagicMock( + email="a@test.com", + code="y", + is_bound_to_account=MagicMock(return_value=True), + ), ), ): with pytest.raises(EmailCodeError): diff --git a/api/tests/unit_tests/libs/test_token_manager.py b/api/tests/unit_tests/libs/test_token_manager.py index 832210c7f2..bbe8a7e30b 100644 --- a/api/tests/unit_tests/libs/test_token_manager.py +++ b/api/tests/unit_tests/libs/test_token_manager.py @@ -1,58 +1,139 @@ """ -Regression tests for the `_TokenData` TypedDict used by -`libs.helper.TokenManager`. +Regression tests for `libs.helper.TokenManager`. -These tests guard the contract that every field a caller writes via -`generate_token` survives the TypedDict-validated round-trip performed -by `get_token_data`. Specifically, the `phase` field that the console -and web `forgot-password` + `change-email` controllers depend on for -the security check introduced in PR #35425 (GHSA-4q3w-q5mc-45rq) must -be preserved — otherwise downstream `if data.get("phase", "") != "reset"` -checks always fail with `InvalidTokenError`. +`TokenManager` is the storage primitive shared by multiple auth flows, so it +must preserve every metadata field written by the caller. Business-specific +validation now happens at the callsite boundary (for example, +`AccountService.get_change_email_data`), not inside `TokenManager`. """ import json +from types import SimpleNamespace -# pyright: reportPrivateUsage=false -from libs.helper import _token_data_adapter +import pytest +from pydantic import ValidationError + +import libs.helper as helper_module +from libs.helper import TokenManager -def test_token_data_adapter_preserves_phase_field() -> None: - """`phase` written by callers like generate_reset_password_token must - survive the TypedDict-validated round-trip in get_token_data. +def _build_fake_redis(storage: dict[str, str]): + def store_value(key: str, _ttl: int, value: str) -> bool: + storage[key] = value + return True - Regression: PR #34380 introduced `_TokenData` but did not list - `phase`, so the TypeAdapter silently dropped it and the security - gate from PR #35425 (GHSA-4q3w-q5mc-45rq) always failed. - """ - payload = { - "account_id": None, - "email": "user@example.com", - "token_type": "reset_password", - "code": "123456", - "phase": "reset", - } - data = dict(_token_data_adapter.validate_json(json.dumps(payload))) + def load_value(key: str) -> str | None: + return storage.get(key) - assert data.get("phase") == "reset", ( - "phase field was stripped by the _TokenData TypedDict adapter; " - "the forgot-password phase-bound check (PR #35425) will always fail." + return SimpleNamespace( + setex=store_value, + get=load_value, + delete=lambda *_args, **_kwargs: None, ) -def test_token_data_adapter_preserves_change_email_payload() -> None: - """Sanity round-trip for the change-email flow: every field set by - `generate_change_email_token` must come back, including the phase - string the controller branches on.""" - payload = { - "account_id": "acc-1", - "email": "new@example.com", - "token_type": "change_email", - "code": "654321", - "old_email": "old@example.com", - "phase": "verify_old_email", - } - data = dict(_token_data_adapter.validate_json(json.dumps(payload))) +def test_token_manager_roundtrip_preserves_untyped_metadata_keys(monkeypatch: pytest.MonkeyPatch) -> None: + """`TokenManager` must round-trip arbitrary metadata keys without silently + dropping fields such as `phase`, `email_change_phase`, or future auth + payload extensions. + """ + storage: dict[str, str] = {} + monkeypatch.setattr(helper_module, "redis_client", _build_fake_redis(storage)) + + token = TokenManager.generate_token( + email="user@example.com", + token_type="change_email", + additional_data={ + "code": "654321", + "old_email": "old@example.com", + "phase": "legacy-phase", + "email_change_phase": "old_email", + "custom_marker": "preserve-me", + }, + ) + + data = TokenManager.get_token_data(token, "change_email") + + assert data is not None + assert data.get("phase") == "legacy-phase" + assert data.get("email_change_phase") == "old_email" + assert data.get("custom_marker") == "preserve-me" + + +def test_token_manager_roundtrip_uses_explicit_email_with_account(monkeypatch: pytest.MonkeyPatch) -> None: + """When both `account` and `email` are supplied, the token should bind the + stable `account_id` from the account and the target email from the explicit + email argument. + """ + + storage: dict[str, str] = {} + monkeypatch.setattr(helper_module, "redis_client", _build_fake_redis(storage)) + + account = SimpleNamespace(id="acc-1", email="old@example.com") + + token = TokenManager.generate_token( + account=account, + email="new@example.com", + token_type="change_email", + additional_data={ + "code": "654321", + "old_email": "old@example.com", + "email_change_phase": "new_email", + }, + ) + + data = TokenManager.get_token_data(token, "change_email") + + assert data is not None + assert data.get("account_id") == "acc-1" + assert data.get("email") == "new@example.com" assert data.get("old_email") == "old@example.com" - assert data.get("phase") == "verify_old_email" + assert data.get("email_change_phase") == "new_email" + + +def test_token_manager_roundtrip_still_validates_declared_fields(monkeypatch: pytest.MonkeyPatch) -> None: + """Unknown fields should be preserved, but declared baseline fields should + still be validated by `_token_data_adapter`. + """ + + storage = { + "change_email:token:token-123": json.dumps( + { + "token_type": "change_email", + "account_id": "acc-1", + "email": ["not-a-string"], + "code": "654321", + "old_email": "old@example.com", + "email_change_phase": "old_email", + } + ) + } + monkeypatch.setattr(helper_module, "redis_client", _build_fake_redis(storage)) + + with pytest.raises(ValidationError): + TokenManager.get_token_data("token-123", "change_email") + + +def test_token_manager_roundtrip_validates_email_change_phase_as_string(monkeypatch: pytest.MonkeyPatch) -> None: + """`email_change_phase` is part of the shared baseline schema, so obviously + malformed discriminator values should fail before the change-email-specific + union parsing at the callsite boundary. + """ + + storage = { + "change_email:token:token-456": json.dumps( + { + "token_type": "change_email", + "account_id": "acc-1", + "email": "new@example.com", + "code": "654321", + "old_email": "old@example.com", + "email_change_phase": ["not-a-string"], + } + ) + } + monkeypatch.setattr(helper_module, "redis_client", _build_fake_redis(storage)) + + with pytest.raises(ValidationError): + TokenManager.get_token_data("token-456", "change_email")