diff --git a/api/commands.py b/api/commands.py index e24b1826ee..20ce22a6c7 100644 --- a/api/commands.py +++ b/api/commands.py @@ -35,7 +35,7 @@ from libs.rsa import generate_key_pair from models import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile +from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from models.provider import Provider, ProviderModel from models.provider_ids import DatasourceProviderID, ToolProviderID @@ -64,8 +64,10 @@ def reset_password(email, new_password, password_confirm): if str(new_password).strip() != str(password_confirm).strip(): click.echo(click.style("Passwords do not match.", fg="red")) return + normalized_email = email.strip().lower() + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = session.query(Account).where(Account.email == email).one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) @@ -86,7 +88,7 @@ def reset_password(email, new_password, password_confirm): base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(email) + AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -102,20 +104,22 @@ def reset_email(email, new_email, email_confirm): if str(new_email).strip() != str(email_confirm).strip(): click.echo(click.style("New emails do not match.", fg="red")) return + normalized_new_email = new_email.strip().lower() + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = session.query(Account).where(Account.email == email).one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) return try: - email_validate(new_email) + email_validate(normalized_new_email) except: click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account.email = new_email + account.email = normalized_new_email click.echo(click.style("Email updated successfully.", fg="green")) @@ -660,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No return # Create account - email = email.strip() + email = email.strip().lower() if "@" not in email: click.echo(click.style("Invalid email address.", fg="red")) diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index fe70d930fb..cfc673880c 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -63,10 +63,9 @@ class ActivateCheckApi(Resource): args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore workspaceId = args.workspace_id - reg_email = args.email token = args.token - invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) + invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token) if invitation: data = invitation.get("data", {}) tenant = invitation.get("tenant", None) @@ -100,11 +99,12 @@ class ActivateApi(Resource): def post(self): args = ActivatePayload.model_validate(console_ns.payload) - invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token) + normalized_request_email = args.email.lower() if args.email else None + invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args.workspace_id, args.email, args.token) + RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token) account = invitation["account"] account.name = args.name diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index fa082c735d..c2a95ddad2 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,6 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource): @email_register_enabled def post(self): args = EmailRegisterSendPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource): if args.language in languages: language = args.language - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() - token = None - token = AccountService.send_email_register_email(email=args.email, account=account, language=language) + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource): def post(self): args = EmailRegisterValidityPayload.model_validate(console_ns.payload) - user_email = args.email + user_email = args.email.lower() - is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email) + is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email) if is_email_register_error_rate_limit: raise EmailRegisterLimitError() @@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_email_register_error_rate_limit(args.email) + AccountService.add_email_register_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource): user_email, code=args.code, additional_data={"phase": "register"} ) - AccountService.reset_email_register_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_email_register_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/email-register") @@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource): AccountService.revoke_email_register_token(args.token) email = register_data.get("email", "") + normalized_email = email.lower() with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: raise EmailAlreadyInUseError() else: - account = self._create_new_account(email, args.password_confirm) + account = self._create_new_account(normalized_email, args.password_confirm) if not account: raise AccountNotFoundError() token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(email) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} - def _create_new_account(self, email, password) -> Account | None: + def _create_new_account(self, email: str, password: str) -> Account | None: # Create new account if allowed account = None try: diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 661f591182..394f205d93 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from controllers.console import console_ns @@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password -from models import Account from services.account_service import AccountService, TenantService from services.feature_service import FeatureService @@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource): @email_password_login_enabled def post(self): args = ForgotPasswordSendPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource): language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_reset_password_email( account=account, - email=args.email, + email=normalized_email, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, ) @@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource): def post(self): args = ForgotPasswordCheckPayload.model_validate(console_ns.payload) - user_email = args.email + user_email = args.email.lower() - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() @@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + if not isinstance(token_email, str): + raise InvalidEmailError() + normalized_token_email = token_email.lower() + + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(args.email) + AccountService.add_forgot_password_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource): # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=args.code, additional_data={"phase": "reset"} + token_email, code=args.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_forgot_password_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/forgot-password/resets") @@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: self._update_existing_account(account, password_hashed, salt, session) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 4a52bf8abe..400df138b8 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -90,32 +90,38 @@ class LoginApi(Resource): def post(self): """Authenticate user and login.""" args = LoginPayload.model_validate(console_ns.payload) + request_email = args.email + normalized_email = request_email.lower() - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email) + is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email) if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() + invite_token = args.invite_token invitation_data: dict[str, Any] | None = None - if args.invite_token: - invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token) + if invite_token: + invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token) + if invitation_data is None: + invite_token = None try: if invitation_data: data = invitation_data.get("data", {}) invitee_email = data.get("email") if data else None - if invitee_email != args.email: + invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email + if invitee_email_normalized != normalized_email: raise InvalidEmailError() - account = AccountService.authenticate(args.email, args.password, args.invite_token) - else: - account = AccountService.authenticate(args.email, args.password) + account = _authenticate_account_with_case_fallback( + request_email, normalized_email, args.password, invite_token + ) except services.errors.account.AccountLoginError: raise AccountBannedError() - except services.errors.account.AccountPasswordError: - AccountService.add_login_error_rate_limit(args.email) - raise AuthenticationFailedError() + except services.errors.account.AccountPasswordError as exc: + AccountService.add_login_error_rate_limit(normalized_email) + raise AuthenticationFailedError() from exc # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: @@ -130,7 +136,7 @@ class LoginApi(Resource): } token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(args.email) + AccountService.reset_login_error_rate_limit(normalized_email) # Create response with cookies instead of returning tokens in body response = make_response({"result": "success"}) @@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource): @console_ns.expect(console_ns.models[EmailPayload.__name__]) def post(self): args = EmailPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() if args.language is not None and args.language == "zh-Hans": language = "zh-Hans" else: language = "en-US" try: - account = AccountService.get_user_through_email(args.email) + account = _get_account_with_case_fallback(args.email) except AccountRegisterError: raise AccountInFreezeError() token = AccountService.send_reset_password_email( - email=args.email, + email=normalized_email, account=account, language=language, is_allow_register=FeatureService.get_system_features().is_allow_register, @@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource): @console_ns.expect(console_ns.models[EmailPayload.__name__]) def post(self): args = EmailPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): @@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource): else: language = "en-US" try: - account = AccountService.get_user_through_email(args.email) + account = _get_account_with_case_fallback(args.email) except AccountRegisterError: raise AccountInFreezeError() if account is None: if FeatureService.get_system_features().is_allow_register: - token = AccountService.send_email_code_login_email(email=args.email, language=language) + token = AccountService.send_email_code_login_email(email=normalized_email, language=language) else: raise AccountNotFound() else: @@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource): def post(self): args = EmailCodeLoginPayload.model_validate(console_ns.payload) - user_email = args.email + original_email = args.email + user_email = original_email.lower() language = args.language token_data = AccountService.get_email_code_login_data(args.token) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args.email: + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if normalized_token_email != user_email: raise InvalidEmailError() if token_data["code"] != args.code: @@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource): AccountService.revoke_email_code_login_token(args.token) try: - account = AccountService.get_user_through_email(user_email) + account = _get_account_with_case_fallback(original_email) except AccountRegisterError: raise AccountInFreezeError() if account: @@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource): except WorkspacesLimitExceededError: raise WorkspacesLimitExceeded() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(args.email) + AccountService.reset_login_error_rate_limit(user_email) # Create response with cookies instead of returning tokens in body response = make_response({"result": "success"}) @@ -309,3 +320,22 @@ class RefreshTokenApi(Resource): return response except Exception as e: return {"result": "fail", "message": str(e)}, 401 + + +def _get_account_with_case_fallback(email: str): + account = AccountService.get_user_through_email(email) + if account or email == email.lower(): + return account + + return AccountService.get_user_through_email(email.lower()) + + +def _authenticate_account_with_case_fallback( + original_email: str, normalized_email: str, password: str, invite_token: str | None +): + try: + return AccountService.authenticate(original_email, password, invite_token) + except services.errors.account.AccountPasswordError: + if original_email == normalized_email: + raise + return AccountService.authenticate(normalized_email, password, invite_token) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index c20e83d36f..112e152432 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -3,7 +3,6 @@ import logging import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized @@ -118,7 +117,10 @@ class OAuthCallback(Resource): invitation = RegisterService.get_invitation_by_token(token=invite_token) if invitation: invitation_email = invitation.get("email", None) - if invitation_email != user_info.email: + invitation_email_normalized = ( + invitation_email.lower() if isinstance(invitation_email, str) else invitation_email + ) + if invitation_email_normalized != user_info.email.lower(): return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") @@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> if not account: with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) return account @@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, tenant_was_created.send(new_tenant) if not account: + normalized_email = user_info.email.lower() oauth_new_user = True if not FeatureService.get_system_features().is_allow_register: - if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountRegisterError( description=( "This email account has been deleted within the past " @@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, raise AccountRegisterError(description=("Invalid email or password")) account_name = user_info.name or "Dify" account = RegisterService.register( - email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider + email=normalized_email, + name=account_name, + password=None, + open_id=user_info.id, + provider=provider, ) # Set interface language diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 7fa02ae280..ed22ef045d 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -84,10 +84,11 @@ class SetupApi(Resource): raise NotInitValidateError() args = SetupRequestPayload.model_validate(console_ns.payload) + normalized_email = args.email.lower() # setup RegisterService.setup( - email=args.email, + email=normalized_email, name=args.name, password=args.password, ip_address=extract_remote_ip(request), diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 03ad0f423b..527aabbc3d 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -41,7 +41,7 @@ from fields.member_fields import account_fields from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required -from models import Account, AccountIntegrate, InvitationCode +from models import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource): else: language = "en-US" account = None - user_email = args.email + user_email = None + email_for_sending = args.email.lower() if args.phase is not None and args.phase == "new_email": if args.token is None: raise InvalidTokenError() @@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource): raise InvalidTokenError() user_email = reset_data.get("email", "") - if user_email != current_user.email: + if user_email.lower() != current_user.email.lower(): raise InvalidEmailError() + + user_email = current_user.email else: with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) if account is None: raise AccountNotFound() + email_for_sending = account.email + user_email = account.email token = AccountService.send_change_email_email( - account=account, email=args.email, old_email=user_email, language=language, phase=args.phase + account=account, + email=email_for_sending, + old_email=user_email, + language=language, + phase=args.phase, ) return {"result": "success", "data": token} @@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource): payload = console_ns.payload or {} args = ChangeEmailValidityPayload.model_validate(payload) - user_email = args.email + user_email = args.email.lower() - is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email) + is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email) if is_change_email_error_rate_limit: raise EmailChangeLimitError() @@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if user_email != normalized_token_email: raise InvalidEmailError() if args.code != token_data.get("code"): - AccountService.add_change_email_error_rate_limit(args.email) + AccountService.add_change_email_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource): user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} ) - AccountService.reset_change_email_error_rate_limit(args.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_change_email_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @console_ns.route("/account/change-email/reset") @@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource): def post(self): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) + normalized_new_email = args.new_email.lower() - if AccountService.is_account_in_freeze(args.new_email): + if AccountService.is_account_in_freeze(normalized_new_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(args.new_email): + if not AccountService.check_email_unique(normalized_new_email): raise EmailAlreadyInUseError() reset_data = AccountService.get_change_email_data(args.token) @@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource): old_email = reset_data.get("old_email", "") current_user, _ = current_account_with_tenant() - if current_user.email != old_email: + if current_user.email.lower() != old_email.lower(): raise AccountNotFound() - updated_account = AccountService.update_account_email(current_user, email=args.new_email) + updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) AccountService.send_change_email_completed_notify_email( - email=args.new_email, + email=normalized_new_email, ) return updated_account @@ -645,8 +657,9 @@ class CheckEmailUnique(Resource): def post(self): payload = console_ns.payload or {} args = CheckEmailUniquePayload.model_validate(payload) - if AccountService.is_account_in_freeze(args.email): + normalized_email = args.email.lower() + if AccountService.is_account_in_freeze(normalized_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(args.email): + if not AccountService.check_email_unique(normalized_email): raise EmailAlreadyInUseError() return {"result": "success"} diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 0142e14fb0..e9bd2b8f94 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource): raise WorkspaceMembersLimitExceeded() for invitee_email in invitee_emails: + normalized_invitee_email = invitee_email.lower() try: if not inviter.current_tenant: raise ValueError("No current tenant") token = RegisterService.invite_new_member( - inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter + tenant=inviter.current_tenant, + email=invitee_email, + language=interface_language, + role=invitee_role, + inviter=inviter, ) - encoded_invitee_email = parse.quote(invitee_email) + encoded_invitee_email = parse.quote(normalized_invitee_email) invitation_results.append( { "status": "success", - "email": invitee_email, + "email": normalized_invitee_email, "url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}", } ) except AccountAlreadyInTenantError: invitation_results.append( - {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} + {"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"} ) except Exception as e: - invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) + invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)}) return { "result": "success", diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 690b76655f..91d206f727 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy import select from sqlalchemy.orm import Session from controllers.common.schema import register_schema_models @@ -22,7 +21,7 @@ from controllers.web import web_ns from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import hash_password, valid_password -from models import Account +from models.account import Account from services.account_service import AccountService @@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource): def post(self): payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {}) + request_email = payload.email + normalized_email = request_email.lower() + ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() @@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource): language = "en-US" with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) token = None if account is None: raise AuthenticationFailedError() else: - token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language) + token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language) return {"result": "success", "data": token} @@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource): def post(self): payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {}) - user_email = payload.email + user_email = payload.email.lower() - is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email) + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email) if is_forgot_password_error_rate_limit: raise EmailPasswordResetLimitError() @@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource): if token_data is None: raise InvalidTokenError() - if user_email != token_data.get("email"): + token_email = token_data.get("email") + if not isinstance(token_email, str): + raise InvalidEmailError() + normalized_token_email = token_email.lower() + + if user_email != normalized_token_email: raise InvalidEmailError() if payload.code != token_data.get("code"): - AccountService.add_forgot_password_error_rate_limit(payload.email) + AccountService.add_forgot_password_error_rate_limit(user_email) raise EmailCodeError() # Verified, revoke the first token @@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource): # Refresh token data by generating a new token _, new_token = AccountService.generate_reset_password_token( - user_email, code=payload.code, additional_data={"phase": "reset"} + token_email, code=payload.code, additional_data={"phase": "reset"} ) - AccountService.reset_forgot_password_error_rate_limit(payload.email) - return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + AccountService.reset_forgot_password_error_rate_limit(user_email) + return {"is_valid": True, "email": normalized_token_email, "token": new_token} @web_ns.route("/forgot-password/resets") @@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") with Session(db.engine) as session: - account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: self._update_existing_account(account, password_hashed, salt, session) diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 5847f4ae3a..e8053acdfd 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource): ) args = parser.parse_args() - user_email = args["email"] + user_email = args["email"].lower() token_data = WebAppAuthService.get_email_code_login_data(args["token"]) if token_data is None: raise InvalidTokenError() - if token_data["email"] != args["email"]: + token_email = token_data.get("email") + if not isinstance(token_email, str): + raise InvalidEmailError() + normalized_token_email = token_email.lower() + if normalized_token_email != user_email: raise InvalidEmailError() if token_data["code"] != args["code"]: raise EmailCodeError() WebAppAuthService.revoke_email_code_login_token(args["token"]) - account = WebAppAuthService.get_user_through_email(user_email) + account = WebAppAuthService.get_user_through_email(token_email) if not account: raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) - AccountService.reset_login_error_rate_limit(args["email"]) + AccountService.reset_login_error_rate_limit(user_email) response = make_response({"result": "success", "data": {"access_token": token}}) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) return response diff --git a/api/services/account_service.py b/api/services/account_service.py index d38c9d5a66..709ef749bc 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from hashlib import sha256 from typing import Any, cast from pydantic import BaseModel -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized @@ -748,6 +748,21 @@ class AccountService: cls.email_code_login_rate_limiter.increment_rate_limit(email) return token + @staticmethod + def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + """ + Retrieve an account by email and fall back to the lowercase email if the original lookup fails. + + This keeps backward compatibility for older records that stored uppercase emails while the + rest of the system gradually normalizes new inputs. + """ + query_session = session or db.session + account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + if account or email == email.lower(): + return account + + return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: return TokenManager.get_token_data(token, "email_code_login") @@ -1363,16 +1378,22 @@ class RegisterService: if not inviter: raise ValueError("Inviter is required") + normalized_email = email.lower() + """Invite new member""" with Session(db.engine) as session: - account = session.query(Account).filter_by(email=email).first() + account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") - name = email.split("@")[0] + name = normalized_email.split("@")[0] account = cls.register( - email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True + email=normalized_email, + name=name, + language=language, + status=AccountStatus.PENDING, + is_setup=True, ) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) @@ -1394,7 +1415,7 @@ class RegisterService: # send email send_invite_member_mail_task.delay( language=language, - to=email, + to=account.email, token=token, inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, @@ -1493,6 +1514,16 @@ class RegisterService: invitation: dict = json.loads(data) return invitation + @classmethod + def get_invitation_with_case_fallback( + cls, workspace_id: str | None, email: str | None, token: str + ) -> dict[str, Any] | None: + invitation = cls.get_invitation_if_token_valid(workspace_id, email, token) + if invitation or not email or email == email.lower(): + return invitation + normalized_email = email.lower() + return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token) + def _generate_refresh_token(length: int = 64): token = secrets.token_hex(length) diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 9bd797a45f..5ca0b63001 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -12,6 +12,7 @@ from libs.passport import PassportService from libs.password import compare_password from models import Account, AccountStatus from models.model import App, EndUser, Site +from services.account_service import AccountService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError @@ -32,7 +33,7 @@ class WebAppAuthService: @staticmethod def authenticate(email: str, password: str) -> Account: """authenticate account with email and password""" - account = db.session.query(Account).filter_by(email=email).first() + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: raise AccountNotFoundError() @@ -52,7 +53,7 @@ class WebAppAuthService: @classmethod def get_user_through_email(cls, email: str): - account = db.session.query(Account).where(Account.email == email).first() + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: return None diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index da21e0e358..d3e864a75a 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -40,7 +40,7 @@ class TestActivateCheckApi: "tenant": tenant, } - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation): """ Test checking valid invitation token. @@ -66,7 +66,7 @@ class TestActivateCheckApi: assert response["data"]["workspace_id"] == "workspace-123" assert response["data"]["email"] == "invitee@example.com" - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_invalid_invitation_token(self, mock_get_invitation, app): """ Test checking invalid invitation token. @@ -88,7 +88,7 @@ class TestActivateCheckApi: # Assert assert response["is_valid"] is False - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation): """ Test checking token without workspace ID. @@ -109,7 +109,7 @@ class TestActivateCheckApi: assert response["is_valid"] is True mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token") - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation): """ Test checking token without email parameter. @@ -130,6 +130,20 @@ class TestActivateCheckApi: assert response["is_valid"] is True mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") + def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation): + """Ensure token validation uses lowercase emails.""" + mock_get_invitation.return_value = mock_invitation + + with app.test_request_context( + "/activate/check?workspace_id=workspace-123&email=Invitee@Example.com&token=valid_token" + ): + api = ActivateCheckApi() + response = api.get() + + assert response["is_valid"] is True + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") + class TestActivateApi: """Test cases for account activation endpoint.""" @@ -212,7 +226,7 @@ class TestActivateApi: mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") mock_db.session.commit.assert_called_once() - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") def test_activation_with_invalid_token(self, mock_get_invitation, app): """ Test account activation with invalid token. @@ -241,7 +255,7 @@ class TestActivateApi: with pytest.raises(AlreadyActivateError): api.post() - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") def test_activation_sets_interface_theme( @@ -290,7 +304,7 @@ class TestActivateApi: ("es-ES", "Europe/Madrid"), ], ) - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") def test_activation_with_different_locales( @@ -336,7 +350,7 @@ class TestActivateApi: assert mock_account.interface_language == language assert mock_account.timezone == timezone - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") def test_activation_returns_success_response( @@ -376,7 +390,7 @@ class TestActivateApi: # Assert assert response == {"result": "success"} - @patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.activate.RegisterService.revoke_token") @patch("controllers.console.auth.activate.db") def test_activation_without_workspace_id( @@ -415,3 +429,37 @@ class TestActivateApi: # Assert assert response["result"] == "success" mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token") + + @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") + @patch("controllers.console.auth.activate.RegisterService.revoke_token") + @patch("controllers.console.auth.activate.db") + def test_activation_normalizes_email_before_lookup( + self, + mock_db, + mock_revoke_token, + mock_get_invitation, + app, + mock_invitation, + mock_account, + ): + """Ensure uppercase emails are normalized before lookup and revocation.""" + mock_get_invitation.return_value = mock_invitation + + with app.test_request_context( + "/activate", + method="POST", + json={ + "workspace_id": "workspace-123", + "email": "Invitee@Example.com", + "token": "valid_token", + "name": "John Doe", + "interface_language": "en-US", + "timezone": "UTC", + }, + ): + api = ActivateApi() + response = api.post() + + assert response["result"] == "success" + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") + mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py index eb21920117..cb4fe40944 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py +++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py @@ -34,7 +34,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invalid_email_with_registration_allowed( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): @@ -67,7 +67,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_wrong_password_returns_error( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db ): @@ -100,7 +100,7 @@ class TestAuthenticationSecurity: @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invalid_email_with_registration_disabled( self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db ): diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/unit_tests/controllers/console/auth/test_email_register.py new file mode 100644 index 0000000000..724c80f18c --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_email_register.py @@ -0,0 +1,177 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.email_register import ( + EmailRegisterCheckApi, + EmailRegisterResetApi, + EmailRegisterSendEmailApi, +) +from services.account_service import AccountService + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class TestEmailRegisterSendEmailApi: + @patch("controllers.console.auth.email_register.Session") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") + @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") + @patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_send_email_normalizes_and_falls_back( + self, + mock_extract_ip, + mock_is_email_send_ip_limit, + mock_is_freeze, + mock_send_mail, + mock_get_account, + mock_session_cls, + app, + ): + mock_send_mail.return_value = "token-123" + mock_is_freeze.return_value = False + mock_account = MagicMock() + + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + mock_get_account.return_value = mock_account + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), + patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register/send-email", + method="POST", + json={"email": "Invitee@Example.com", "language": "en-US"}, + ): + response = EmailRegisterSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_is_freeze.assert_called_once_with("invitee@example.com") + mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") + mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) + mock_extract_ip.assert_called_once() + mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") + + +class TestEmailRegisterCheckApi: + @patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.generate_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit") + def test_validity_normalizes_email_before_checks( + self, + mock_rate_limit_check, + mock_get_data, + mock_add_rate, + mock_revoke, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_rate_limit_check.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"} + mock_generate_token.return_value = (None, "new-token") + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register/validity", + method="POST", + json={"email": "User@Example.com", "code": "4321", "token": "token-123"}, + ): + response = EmailRegisterCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_rate_limit_check.assert_called_once_with("user@example.com") + mock_generate_token.assert_called_once_with( + "user@example.com", code="4321", additional_data={"phase": "register"} + ) + mock_reset_rate.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke.assert_called_once_with("token-123") + + +class TestEmailRegisterResetApi: + @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") + @patch("controllers.console.auth.email_register.AccountService.login") + @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") + @patch("controllers.console.auth.email_register.Session") + @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") + @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") + @patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1") + def test_reset_creates_account_with_normalized_email( + self, + mock_extract_ip, + mock_get_data, + mock_revoke_token, + mock_get_account, + mock_session_cls, + mock_create_account, + mock_login, + mock_reset_login_rate, + app, + ): + mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} + mock_create_account.return_value = MagicMock() + token_pair = MagicMock() + token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} + mock_login.return_value = token_pair + + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + mock_get_account.return_value = None + + feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + with ( + patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), + ): + with app.test_request_context( + "/email-register", + method="POST", + json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"}, + ): + response = EmailRegisterResetApi().post() + + assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}} + mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!") + mock_reset_login_rate.assert_called_once_with("invitee@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_extract_ip.assert_called_once() + mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) + + +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): + mock_session = MagicMock() + first_query = MagicMock() + first_query.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_query = MagicMock() + second_query.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_query, second_query] + + account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + + assert account is expected_account + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py new file mode 100644 index 0000000000..8403777dc9 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py @@ -0,0 +1,176 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.forgot_password import ( + ForgotPasswordCheckApi, + ForgotPasswordResetApi, + ForgotPasswordSendEmailApi, +) +from services.account_service import AccountService + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class TestForgotPasswordSendEmailApi: + @patch("controllers.console.auth.forgot_password.Session") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") + @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1") + def test_send_normalizes_email( + self, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_account, + mock_session_cls, + app, + ): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_send_email.return_value = "token-123" + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) + controller_features = SimpleNamespace(is_allow_register=True) + with ( + patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), + patch( + "controllers.console.auth.forgot_password.FeatureService.get_system_features", + return_value=controller_features, + ), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + with app.test_request_context( + "/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_send_email.assert_called_once_with( + account=mock_account, + email="user@example.com", + language="zh-Hans", + is_allow_register=True, + ) + mock_is_ip_limit.assert_called_once_with("127.0.0.1") + mock_extract_ip.assert_called_once() + + +class TestForgotPasswordCheckApi: + @patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") + @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_check_normalizes_email( + self, + mock_rate_limit_check, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_rate_limit_check.return_value = False + mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"} + mock_generate_token.return_value = (None, "new-token") + + wraps_features = SimpleNamespace(enable_email_password_login=True) + with ( + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + with app.test_request_context( + "/forgot-password/validity", + method="POST", + json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"} + mock_rate_limit_check.assert_called_once_with("admin@example.com") + mock_generate_token.assert_called_once_with( + "Admin@Example.com", + code="4321", + additional_data={"phase": "reset"}, + ) + mock_reset_rate.assert_called_once_with("admin@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + + +class TestForgotPasswordResetApi: + @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") + @patch("controllers.console.auth.forgot_password.Session") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") + def test_reset_fetches_account_with_original_email( + self, + mock_get_reset_data, + mock_revoke_token, + mock_get_account, + mock_session_cls, + mock_update_account, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} + mock_account = MagicMock() + mock_get_account.return_value = mock_account + + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + wraps_features = SimpleNamespace(enable_email_password_login=True) + with ( + patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), + patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + with app.test_request_context( + "/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_reset_data.assert_called_once_with("token-123") + mock_revoke_token.assert_called_once_with("token-123") + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_update_account.assert_called_once() + + +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): + mock_session = MagicMock() + first_query = MagicMock() + first_query.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_query = MagicMock() + second_query.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_query, second_query] + + account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + + assert account is expected_account + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 3a2cf7bad7..560971206f 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -76,7 +76,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.AccountService.login") @@ -120,7 +120,7 @@ class TestLoginApi: response = login_api.post() # Assert - mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!") + mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None) mock_login.assert_called_once() mock_reset_rate_limit.assert_called_once_with("test@example.com") assert response.json["result"] == "success" @@ -128,7 +128,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.AccountService.login") @@ -182,7 +182,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): """ Test login rejection when rate limit is exceeded. @@ -230,7 +230,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") def test_login_fails_with_invalid_credentials( @@ -269,7 +269,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") def test_login_fails_for_banned_account( self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app @@ -298,7 +298,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") @patch("controllers.console.auth.login.TenantService.get_join_tenants") @patch("controllers.console.auth.login.FeatureService.get_system_features") @@ -343,7 +343,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") - @patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): """ Test login failure when invitation email doesn't match login email. @@ -371,6 +371,52 @@ class TestLoginApi: with pytest.raises(InvalidEmailError): login_api.post() + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) + @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") + @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") + @patch("controllers.console.auth.login.AccountService.authenticate") + @patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit") + @patch("controllers.console.auth.login.TenantService.get_join_tenants") + @patch("controllers.console.auth.login.AccountService.login") + @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit") + def test_login_retries_with_lowercase_email( + self, + mock_reset_rate_limit, + mock_login_service, + mock_get_tenants, + mock_add_rate_limit, + mock_authenticate, + mock_get_invitation, + mock_is_rate_limit, + mock_db, + app, + mock_account, + mock_token_pair, + ): + """Test that login retries with lowercase email when uppercase lookup fails.""" + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_is_rate_limit.return_value = False + mock_get_invitation.return_value = None + mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account] + mock_get_tenants.return_value = [MagicMock()] + mock_login_service.return_value = mock_token_pair + + with app.test_request_context( + "/login", + method="POST", + json={"email": "Upper@Example.com", "password": encode_password("ValidPass123!")}, + ): + response = LoginApi().post() + + assert response.json["result"] == "success" + assert mock_authenticate.call_args_list == [ + (("Upper@Example.com", "ValidPass123!", None), {}), + (("upper@example.com", "ValidPass123!", None), {}), + ] + mock_add_rate_limit.assert_not_called() + mock_reset_rate_limit.assert_called_once_with("upper@example.com") + class TestLogoutApi: """Test cases for the LogoutApi endpoint.""" diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 3ddfcdb832..6345c2ab23 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -12,6 +12,7 @@ from controllers.console.auth.oauth import ( ) from libs.oauth import OAuthUserInfo from models.account import AccountStatus +from services.account_service import AccountService from services.errors.account import AccountRegisterError @@ -215,6 +216,34 @@ class TestOAuthCallback: assert status_code == 400 assert response["error"] == expected_error + @patch("controllers.console.auth.oauth.dify_config") + @patch("controllers.console.auth.oauth.get_oauth_providers") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.redirect") + def test_invitation_comparison_is_case_insensitive( + self, + mock_redirect, + mock_register_service, + mock_get_providers, + mock_config, + resource, + app, + oauth_setup, + ): + mock_config.CONSOLE_WEB_URL = "http://localhost:3000" + oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo( + id="123", name="Test User", email="User@Example.com" + ) + mock_get_providers.return_value = {"github": oauth_setup["provider"]} + mock_register_service.is_valid_invite_token.return_value = True + mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"} + + with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"): + resource.get("github") + + mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123") + mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123") + @pytest.mark.parametrize( ("account_status", "expected_redirect"), [ @@ -395,12 +424,12 @@ class TestAccountGeneration: account.name = "Test User" return account - @patch("controllers.console.auth.oauth.db") - @patch("controllers.console.auth.oauth.Account") + @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.oauth.Session") - @patch("controllers.console.auth.oauth.select") + @patch("controllers.console.auth.oauth.Account") + @patch("controllers.console.auth.oauth.db") def test_should_get_account_by_openid_or_email( - self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account + self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account ): # Mock db.engine for Session creation mock_db.engine = MagicMock() @@ -410,15 +439,31 @@ class TestAccountGeneration: result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account mock_account_model.get_by_openid.assert_called_once_with("github", "123") + mock_get_account.assert_not_called() - # Test fallback to email + # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account + mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) + + def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): + mock_session = MagicMock() + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None + expected_account = MagicMock() + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] + + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + + assert result == expected_account + assert mock_session.execute.call_count == 2 @pytest.mark.parametrize( ("allow_register", "existing_account", "should_create"), @@ -466,6 +511,35 @@ class TestAccountGeneration: mock_register_service.register.assert_called_once_with( email="test@example.com", name="Test User", password=None, open_id="123", provider="github" ) + else: + mock_register_service.register.assert_not_called() + + @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) + @patch("controllers.console.auth.oauth.FeatureService") + @patch("controllers.console.auth.oauth.RegisterService") + @patch("controllers.console.auth.oauth.AccountService") + @patch("controllers.console.auth.oauth.TenantService") + @patch("controllers.console.auth.oauth.db") + def test_should_register_with_lowercase_email( + self, + mock_db, + mock_tenant_service, + mock_account_service, + mock_register_service, + mock_feature_service, + mock_get_account, + app, + ): + user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com") + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_register_service.register.return_value = MagicMock() + + with app.test_request_context(headers={"Accept-Language": "en-US"}): + _generate_account("github", user_info) + + mock_register_service.register.assert_called_once_with( + email="upper@example.com", name="Test User", password=None, open_id="123", provider="github" + ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email") @patch("controllers.console.auth.oauth.TenantService") diff --git a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py b/api/tests/unit_tests/controllers/console/auth/test_password_reset.py index f584952a00..9488cf528e 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/unit_tests/controllers/console/auth/test_password_reset.py @@ -28,6 +28,22 @@ from controllers.console.auth.forgot_password import ( from controllers.console.error import AccountNotFound, EmailSendIpLimitError +@pytest.fixture(autouse=True) +def _mock_forgot_password_session(): + with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + mock_session_cls.return_value.__exit__.return_value = None + yield mock_session + + +@pytest.fixture(autouse=True) +def _mock_forgot_password_db(): + with patch("controllers.console.auth.forgot_password.db") as mock_db: + mock_db.engine = MagicMock() + yield mock_db + + class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @@ -47,20 +63,16 @@ class TestForgotPasswordSendEmailApi: return account @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password.select") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.FeatureService.get_system_features") def test_send_reset_email_success( self, mock_get_features, mock_send_email, - mock_select, - mock_session, + mock_get_account, mock_is_ip_limit, - mock_forgot_db, mock_wraps_db, app, mock_account, @@ -75,11 +87,8 @@ class TestForgotPasswordSendEmailApi: """ # Arrange mock_wraps_db.session.query.return_value.first.return_value = MagicMock() - mock_forgot_db.engine = MagicMock() mock_is_ip_limit.return_value = False - mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account - mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account mock_send_email.return_value = "reset_token_123" mock_get_features.return_value.is_allow_register = True @@ -125,20 +134,16 @@ class TestForgotPasswordSendEmailApi: ], ) @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password.select") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.FeatureService.get_system_features") def test_send_reset_email_language_handling( self, mock_get_features, mock_send_email, - mock_select, - mock_session, + mock_get_account, mock_is_ip_limit, - mock_forgot_db, mock_wraps_db, app, mock_account, @@ -154,11 +159,8 @@ class TestForgotPasswordSendEmailApi: """ # Arrange mock_wraps_db.session.query.return_value.first.return_value = MagicMock() - mock_forgot_db.engine = MagicMock() mock_is_ip_limit.return_value = False - mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account - mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account mock_send_email.return_value = "token" mock_get_features.return_value.is_allow_register = True @@ -229,8 +231,46 @@ class TestForgotPasswordCheckApi: assert response["email"] == "test@example.com" assert response["token"] == "new_token" mock_revoke_token.assert_called_once_with("old_token") + mock_generate_token.assert_called_once_with( + "test@example.com", code="123456", additional_data={"phase": "reset"} + ) mock_reset_rate_limit.assert_called_once_with("test@example.com") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") + @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + def test_verify_code_preserves_token_email_case( + self, + mock_reset_rate_limit, + mock_generate_token, + mock_revoke_token, + mock_get_data, + mock_is_rate_limit, + mock_db, + app, + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} + mock_generate_token.return_value = (None, "fresh-token") + + with app.test_request_context( + "/forgot-password/validity", + method="POST", + json={"email": "user@example.com", "code": "999888", "token": "upper_token"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "fresh-token"} + mock_generate_token.assert_called_once_with( + "User@Example.com", code="999888", additional_data={"phase": "reset"} + ) + mock_revoke_token.assert_called_once_with("upper_token") + mock_reset_rate_limit.assert_called_once_with("user@example.com") + @patch("controllers.console.wraps.db") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app): @@ -355,20 +395,16 @@ class TestForgotPasswordResetApi: return account @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") - @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password.select") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, - mock_select, - mock_session, + mock_get_account, mock_revoke_token, mock_get_data, - mock_forgot_db, mock_wraps_db, app, mock_account, @@ -383,11 +419,8 @@ class TestForgotPasswordResetApi: """ # Arrange mock_wraps_db.session.query.return_value.first.return_value = MagicMock() - mock_forgot_db.engine = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} - mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account - mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act @@ -475,13 +508,11 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") - @patch("controllers.console.auth.forgot_password.Session") - @patch("controllers.console.auth.forgot_password.select") + @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") def test_reset_password_account_not_found( - self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app + self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app ): """ Test password reset for non-existent account. @@ -491,11 +522,8 @@ class TestForgotPasswordResetApi: """ # Arrange mock_wraps_db.session.query.return_value.first.return_value = MagicMock() - mock_forgot_db.engine = MagicMock() mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"} - mock_session_instance = MagicMock() - mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None - mock_session.return_value.__enter__.return_value = mock_session_instance + mock_get_account.return_value = None # Act & Assert with app.test_request_context( diff --git a/api/tests/unit_tests/controllers/console/test_setup.py b/api/tests/unit_tests/controllers/console/test_setup.py new file mode 100644 index 0000000000..e7882dcd2b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_setup.py @@ -0,0 +1,39 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from controllers.console.setup import SetupApi + + +class TestSetupApi: + def test_post_lowercases_email_before_register(self): + """Ensure setup registration normalizes email casing.""" + payload = { + "email": "Admin@Example.com", + "name": "Admin User", + "password": "ValidPass123!", + "language": "en-US", + } + setup_api = SetupApi(api=None) + + mock_console_ns = SimpleNamespace(payload=payload) + + with ( + patch("controllers.console.setup.console_ns", mock_console_ns), + patch("controllers.console.setup.get_setup_status", return_value=False), + patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0), + patch("controllers.console.setup.get_init_validate_status", return_value=True), + patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"), + patch("controllers.console.setup.request", object()), + patch("controllers.console.setup.RegisterService.setup") as mock_register, + ): + response, status = setup_api.post() + + assert response == {"result": "success"} + assert status == 201 + mock_register.assert_called_once_with( + email="admin@example.com", + name=payload["name"], + password=payload["password"], + ip_address="127.0.0.1", + language=payload["language"], + ) diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py new file mode 100644 index 0000000000..9afc1c4166 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -0,0 +1,247 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, g + +from controllers.console.workspace.account import ( + AccountDeleteUpdateFeedbackApi, + ChangeEmailCheckApi, + ChangeEmailResetApi, + ChangeEmailSendEmailApi, + CheckEmailUnique, +) +from models import Account +from services.account_service import AccountService + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["RESTX_MASK_HEADER"] = "X-Fields" + app.login_manager = SimpleNamespace(_load_user=lambda: None) + return app + + +def _mock_wraps_db(mock_db): + mock_db.session.query.return_value.first.return_value = MagicMock() + + +def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account: + tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id") + account = Account(name=account_id, email=email) + account.email = email + account.id = account_id + account.status = "active" + account._current_tenant = tenant_obj + return account + + +def _set_logged_in_user(account: Account): + g._login_user = account + g._current_tenant = account.current_tenant + + +class TestChangeEmailSend: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_normalize_new_email_phase( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_change_data, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = {"email": "current@example.com"} + mock_send_email.return_value = "token-abc" + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailSendEmailApi().post() + + assert response == {"result": "success", "data": "token-abc"} + mock_send_email.assert_called_once_with( + account=None, + email="new@example.com", + old_email="current@example.com", + language="en-US", + phase="new_email", + ) + mock_extract_ip.assert_called_once() + mock_is_ip_limit.assert_called_once_with("127.0.0.1") + mock_csrf.assert_called_once() + + +class TestChangeEmailValidity: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_validate_with_normalized_email( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("user@example.com", "acc2") + mock_current_account.return_value = (mock_account, None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"} + mock_generate_token.return_value = (None, "new-token") + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "User@Example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_is_rate_limit.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + mock_generate_token.assert_called_once_with( + "user@example.com", code="1234", old_email="old@example.com", additional_data={} + ) + mock_reset_rate.assert_called_once_with("user@example.com") + mock_csrf.assert_called_once() + + +class TestChangeEmailReset: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_normalize_new_email_before_update( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + mock_get_data.return_value = {"old_email": "OLD@example.com"} + mock_account_after_update = _build_account("new@example.com", "acc3-updated") + mock_update_account.return_value = mock_account_after_update + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "New@Example.com", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + ChangeEmailResetApi().post() + + mock_is_freeze.assert_called_once_with("new@example.com") + mock_check_unique.assert_called_once_with("new@example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_update_account.assert_called_once_with(current_user, email="new@example.com") + mock_send_notify.assert_called_once_with(email="new@example.com") + mock_csrf.assert_called_once() + + +class TestAccountDeletionFeedback: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") + def test_should_normalize_feedback_email(self, mock_update, mock_db, app): + _mock_wraps_db(mock_db) + with app.test_request_context( + "/account/delete/feedback", + method="POST", + json={"email": "User@Example.com", "feedback": "test"}, + ): + response = AccountDeleteUpdateFeedbackApi().post() + + assert response == {"result": "success"} + mock_update.assert_called_once_with("User@Example.com", "test") + + +class TestCheckEmailUnique: + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): + _mock_wraps_db(mock_db) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + + with app.test_request_context( + "/account/change-email/check-email-unique", + method="POST", + json={"email": "Case@Test.com"}, + ): + response = CheckEmailUnique().post() + + assert response == {"result": "success"} + mock_is_freeze.assert_called_once_with("case@test.com") + mock_check_unique.assert_called_once_with("case@test.com") + + +def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): + session = MagicMock() + first = MagicMock() + first.scalar_one_or_none.return_value = None + second = MagicMock() + expected_account = MagicMock() + second.scalar_one_or_none.return_value = expected_account + session.execute.side_effect = [first, second] + + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + + assert result is expected_account + assert session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py new file mode 100644 index 0000000000..368892b922 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -0,0 +1,82 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, g + +from controllers.console.workspace.members import MemberInviteEmailApi +from models.account import Account, TenantAccountRole + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + flask_app.login_manager = SimpleNamespace(_load_user=lambda: None) + return flask_app + + +def _mock_wraps_db(mock_db): + mock_db.session.query.return_value.first.return_value = MagicMock() + + +def _build_feature_flags(): + placeholder_quota = SimpleNamespace(limit=0, size=0) + workspace_members = SimpleNamespace(is_available=lambda count: True) + return SimpleNamespace( + billing=SimpleNamespace(enabled=False), + workspace_members=workspace_members, + members=placeholder_quota, + apps=placeholder_quota, + vector_space=placeholder_quota, + documents_upload_quota=placeholder_quota, + annotation_quota_limit=placeholder_quota, + ) + + +class TestMemberInviteEmailApi: + @patch("controllers.console.workspace.members.FeatureService.get_features") + @patch("controllers.console.workspace.members.RegisterService.invite_new_member") + @patch("controllers.console.workspace.members.current_account_with_tenant") + @patch("controllers.console.wraps.db") + @patch("libs.login.check_csrf_token", return_value=None) + def test_invite_normalizes_emails( + self, + mock_csrf, + mock_db, + mock_current_account, + mock_invite_member, + mock_get_features, + app, + ): + _mock_wraps_db(mock_db) + mock_get_features.return_value = _build_feature_flags() + mock_invite_member.return_value = "token-abc" + + tenant = SimpleNamespace(id="tenant-1", name="Test Tenant") + inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active") + mock_current_account.return_value = (inviter, tenant.id) + + with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"): + with app.test_request_context( + "/workspaces/current/members/invite-email", + method="POST", + json={"emails": ["User@Example.com"], "role": TenantAccountRole.EDITOR.value, "language": "en-US"}, + ): + account = Account(name="tester", email="tester@example.com") + account._current_tenant = tenant + g._login_user = account + g._current_tenant = tenant + response, status_code = MemberInviteEmailApi().post() + + assert status_code == 201 + assert response["invitation_results"][0]["email"] == "user@example.com" + + assert mock_invite_member.call_count == 1 + call_args = mock_invite_member.call_args + assert call_args.kwargs["tenant"] == tenant + assert call_args.kwargs["email"] == "User@Example.com" + assert call_args.kwargs["language"] == "en-US" + assert call_args.kwargs["role"] == TenantAccountRole.EDITOR + assert call_args.kwargs["inviter"] == inviter + mock_csrf.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_forgot_password.py b/api/tests/unit_tests/controllers/web/test_forgot_password.py deleted file mode 100644 index d7c0d24f14..0000000000 --- a/api/tests/unit_tests/controllers/web/test_forgot_password.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Unit tests for controllers.web.forgot_password endpoints.""" - -from __future__ import annotations - -import base64 -import builtins -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest -from flask import Flask -from flask.views import MethodView - -# Ensure flask_restx.api finds MethodView during import. -if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView # type: ignore[attr-defined] - - -def _load_controller_module(): - """Import controllers.web.forgot_password using a stub package.""" - - import importlib - import importlib.util - import sys - from types import ModuleType - - parent_module_name = "controllers.web" - module_name = f"{parent_module_name}.forgot_password" - - if parent_module_name not in sys.modules: - from flask_restx import Namespace - - stub = ModuleType(parent_module_name) - stub.__file__ = "controllers/web/__init__.py" - stub.__path__ = ["controllers/web"] - stub.__package__ = "controllers" - stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True) - stub.web_ns = Namespace("web", description="Web API", path="/") - sys.modules[parent_module_name] = stub - - return importlib.import_module(module_name) - - -forgot_password_module = _load_controller_module() -ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi -ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi -ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi - - -@pytest.fixture -def app() -> Flask: - """Configure a minimal Flask app for request contexts.""" - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - -@pytest.fixture(autouse=True) -def _enable_web_endpoint_guards(): - """Stub enterprise and feature toggles used by route decorators.""" - - features = SimpleNamespace(enable_email_password_login=True) - with ( - patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True), - patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), - patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features), - ): - yield - - -@pytest.fixture(autouse=True) -def _mock_controller_db(): - """Replace controller-level db reference with a simple stub.""" - - fake_db = SimpleNamespace(engine=MagicMock(name="engine")) - fake_wraps_db = SimpleNamespace( - session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True)))) - ) - with ( - patch("controllers.web.forgot_password.db", fake_db), - patch("controllers.console.wraps.db", fake_wraps_db), - ): - yield fake_db - - -@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token") -@patch("controllers.web.forgot_password.Session") -@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) -@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10") -def test_send_reset_email_success( - mock_extract_ip: MagicMock, - mock_is_ip_limit: MagicMock, - mock_session: MagicMock, - mock_send_email: MagicMock, - app: Flask, -): - """POST /forgot-password returns token when email exists and limits allow.""" - - mock_account = MagicMock() - session_ctx = MagicMock() - mock_session.return_value.__enter__.return_value = session_ctx - session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account - - with app.test_request_context( - "/forgot-password", - method="POST", - json={"email": "user@example.com"}, - ): - response = ForgotPasswordSendEmailApi().post() - - assert response == {"result": "success", "data": "reset-token"} - mock_extract_ip.assert_called_once() - mock_is_ip_limit.assert_called_once_with("203.0.113.10") - mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US") - - -@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") -@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token")) -@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") -@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") -@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False) -def test_check_token_success( - mock_is_rate_limited: MagicMock, - mock_get_data: MagicMock, - mock_revoke: MagicMock, - mock_generate: MagicMock, - mock_reset_limit: MagicMock, - app: Flask, -): - """POST /forgot-password/validity validates the code and refreshes token.""" - - mock_get_data.return_value = {"email": "user@example.com", "code": "123456"} - - with app.test_request_context( - "/forgot-password/validity", - method="POST", - json={"email": "user@example.com", "code": "123456", "token": "old-token"}, - ): - response = ForgotPasswordCheckApi().post() - - assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} - mock_is_rate_limited.assert_called_once_with("user@example.com") - mock_get_data.assert_called_once_with("old-token") - mock_revoke.assert_called_once_with("old-token") - mock_generate.assert_called_once_with( - "user@example.com", - code="123456", - additional_data={"phase": "reset"}, - ) - mock_reset_limit.assert_called_once_with("user@example.com") - - -@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") -@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") -@patch("controllers.web.forgot_password.Session") -@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") -@patch("controllers.web.forgot_password.AccountService.get_reset_password_data") -def test_reset_password_success( - mock_get_data: MagicMock, - mock_revoke_token: MagicMock, - mock_session: MagicMock, - mock_token_bytes: MagicMock, - mock_hash_password: MagicMock, - app: Flask, -): - """POST /forgot-password/resets updates the stored password when token is valid.""" - - mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"} - account = MagicMock() - session_ctx = MagicMock() - mock_session.return_value.__enter__.return_value = session_ctx - session_ctx.execute.return_value.scalar_one_or_none.return_value = account - - with app.test_request_context( - "/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() - - assert response == {"result": "success"} - mock_get_data.assert_called_once_with("reset-token") - mock_revoke_token.assert_called_once_with("reset-token") - mock_token_bytes.assert_called_once_with(16) - mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef") - expected_password = base64.b64encode(b"hashed-value").decode() - assert account.password == expected_password - expected_salt = base64.b64encode(b"0123456789abcdef").decode() - assert account.password_salt == expected_salt - session_ctx.commit.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_web_forgot_password.py b/api/tests/unit_tests/controllers/web/test_web_forgot_password.py new file mode 100644 index 0000000000..3d7c319947 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_forgot_password.py @@ -0,0 +1,226 @@ +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.forgot_password import ( + ForgotPasswordCheckApi, + ForgotPasswordResetApi, + ForgotPasswordSendEmailApi, +) + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture(autouse=True) +def _patch_wraps(): + wraps_features = SimpleNamespace(enable_email_password_login=True) + dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD") + with ( + patch("controllers.console.wraps.db") as mock_db, + patch("controllers.console.wraps.dify_config", dify_settings), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + yield + + +class TestForgotPasswordSendEmailApi: + @patch("controllers.web.forgot_password.AccountService.send_reset_password_email") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") + @patch("controllers.web.forgot_password.Session") + def test_should_normalize_email_before_sending( + self, + mock_session_cls, + mock_extract_ip, + mock_rate_limit, + mock_get_account, + mock_send_mail, + app, + ): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_send_mail.return_value = "token-123" + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") + mock_extract_ip.assert_called_once() + mock_rate_limit.assert_called_once_with("127.0.0.1") + + +class TestForgotPasswordCheckApi: + @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_should_normalize_email_for_validity_checks( + self, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"} + mock_generate_token.return_value = (None, "new-token") + + with app.test_request_context( + "/web/forgot-password/validity", + method="POST", + json={"email": "User@Example.com", "code": "1234", "token": "token-123"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} + mock_is_rate_limit.assert_called_once_with("user@example.com") + mock_add_rate.assert_not_called() + mock_revoke_token.assert_called_once_with("token-123") + mock_generate_token.assert_called_once_with( + "User@Example.com", + code="1234", + additional_data={"phase": "reset"}, + ) + mock_reset_rate.assert_called_once_with("user@example.com") + + @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") + @patch("controllers.web.forgot_password.AccountService.generate_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit") + def test_should_preserve_token_email_case( + self, + mock_is_rate_limit, + mock_get_data, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + app, + ): + mock_is_rate_limit.return_value = False + mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"} + mock_generate_token.return_value = (None, "fresh-token") + + with app.test_request_context( + "/web/forgot-password/validity", + method="POST", + json={"email": "mixedcase@example.com", "code": "5678", "token": "token-upper"}, + ): + response = ForgotPasswordCheckApi().post() + + assert response == {"is_valid": True, "email": "mixedcase@example.com", "token": "fresh-token"} + mock_generate_token.assert_called_once_with( + "MixedCase@Example.com", + code="5678", + additional_data={"phase": "reset"}, + ) + mock_revoke_token.assert_called_once_with("token-upper") + mock_reset_rate.assert_called_once_with("mixedcase@example.com") + + +class TestForgotPasswordResetApi: + @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + def test_should_fetch_account_with_fallback( + self, + mock_get_reset_data, + mock_revoke_token, + mock_session_cls, + mock_get_account, + mock_update_account, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} + mock_account = MagicMock() + mock_get_account.return_value = mock_account + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_update_account.assert_called_once() + mock_revoke_token.assert_called_once_with("token-123") + + @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") + @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") + @patch("controllers.web.forgot_password.Session") + @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") + @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") + @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") + def test_should_update_password_and_commit( + self, + mock_get_account, + mock_get_reset_data, + mock_revoke_token, + mock_session_cls, + mock_token_bytes, + mock_hash_password, + app, + ): + mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} + account = MagicMock() + mock_get_account.return_value = account + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() + + assert response == {"result": "success"} + mock_get_reset_data.assert_called_once_with("reset-token") + mock_revoke_token.assert_called_once_with("reset-token") + mock_token_bytes.assert_called_once_with(16) + mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef") + expected_password = base64.b64encode(b"hashed-value").decode() + assert account.password == expected_password + expected_salt = base64.b64encode(b"0123456789abcdef").decode() + assert account.password_salt == expected_salt + mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py new file mode 100644 index 0000000000..e62993e8d5 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -0,0 +1,91 @@ +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi + + +def encode_code(code: str) -> str: + return base64.b64encode(code.encode("utf-8")).decode() + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture(autouse=True) +def _patch_wraps(): + wraps_features = SimpleNamespace(enable_email_password_login=True) + console_dify = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD") + web_dify = SimpleNamespace(ENTERPRISE_ENABLED=True) + with ( + patch("controllers.console.wraps.db") as mock_db, + patch("controllers.console.wraps.dify_config", console_dify), + patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), + patch("controllers.web.login.dify_config", web_dify), + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + yield + + +class TestEmailCodeLoginSendEmailApi: + @patch("controllers.web.login.WebAppAuthService.send_email_code_login_email") + @patch("controllers.web.login.WebAppAuthService.get_user_through_email") + def test_should_fetch_account_with_original_email( + self, + mock_get_user, + mock_send_email, + app, + ): + mock_account = MagicMock() + mock_get_user.return_value = mock_account + mock_send_email.return_value = "token-123" + + with app.test_request_context( + "/web/email-code-login", + method="POST", + json={"email": "User@Example.com", "language": "en-US"}, + ): + response = EmailCodeLoginSendEmailApi().post() + + assert response == {"result": "success", "data": "token-123"} + mock_get_user.assert_called_once_with("User@Example.com") + mock_send_email.assert_called_once_with(account=mock_account, language="en-US") + + +class TestEmailCodeLoginApi: + @patch("controllers.web.login.AccountService.reset_login_error_rate_limit") + @patch("controllers.web.login.WebAppAuthService.login", return_value="new-access-token") + @patch("controllers.web.login.WebAppAuthService.get_user_through_email") + @patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token") + @patch("controllers.web.login.WebAppAuthService.get_email_code_login_data") + def test_should_normalize_email_before_validating( + self, + mock_get_token_data, + mock_revoke_token, + mock_get_user, + mock_login, + mock_reset_login_rate, + app, + ): + mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} + mock_get_user.return_value = MagicMock() + + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + response = EmailCodeLoginApi().post() + + assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}} + mock_get_user.assert_called_once_with("User@Example.com") + mock_revoke_token.assert_called_once_with("token-123") + mock_login.assert_called_once() + mock_reset_login_rate.assert_called_once_with("user@example.com") diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index e35ba74c56..8ae20f35d8 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from configs import dify_config -from models.account import Account +from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import ( AccountAlreadyInTenantError, @@ -1147,9 +1147,13 @@ class TestRegisterService: mock_session = MagicMock() mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account - with patch("services.account_service.Session") as mock_session_class: + with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): mock_session_class.return_value.__enter__.return_value = mock_session mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = None # Mock RegisterService.register mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( @@ -1182,9 +1186,59 @@ class TestRegisterService: email="newuser@example.com", name="newuser", language="en-US", - status="pending", + status=AccountStatus.PENDING, is_setup=True, ) + mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session) + + def test_invite_new_member_normalizes_new_account_email( + self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies + ): + """Ensure inviting with mixed-case email normalizes before registering.""" + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") + mixed_email = "Invitee@Example.com" + + mock_session = MagicMock() + with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = None + + mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="new-user-789", email="invitee@example.com", name="invitee", status="pending" + ) + with patch("services.account_service.RegisterService.register") as mock_register: + mock_register.return_value = mock_new_account + with ( + patch("services.account_service.TenantService.check_member_permission") as mock_check_permission, + patch("services.account_service.TenantService.create_tenant_member") as mock_create_member, + patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant, + patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token, + ): + mock_generate_token.return_value = "invite-token-abc" + + RegisterService.invite_new_member( + tenant=mock_tenant, + email=mixed_email, + language="en-US", + role="normal", + inviter=mock_inviter, + ) + + mock_register.assert_called_once_with( + email="invitee@example.com", + name="invitee", + language="en-US", + status=AccountStatus.PENDING, + is_setup=True, + ) + mock_lookup.assert_called_once_with(mixed_email, session=mock_session) + mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account) @@ -1207,9 +1261,13 @@ class TestRegisterService: mock_session = MagicMock() mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account - with patch("services.account_service.Session") as mock_session_class: + with ( + patch("services.account_service.Session") as mock_session_class, + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + ): mock_session_class.return_value.__enter__.return_value = mock_session mock_session_class.return_value.__exit__.return_value = None + mock_lookup.return_value = mock_existing_account # Mock the db.session.query for TenantAccountJoin mock_db_query = MagicMock() @@ -1238,6 +1296,7 @@ class TestRegisterService: mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) mock_task_dependencies.delay.assert_called_once() + mock_lookup.assert_called_once_with("existing@example.com", session=mock_session) def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): """Test inviting a member who is already in the tenant.""" @@ -1251,7 +1310,6 @@ class TestRegisterService: # Mock database queries query_results = { - ("Account", "email", "existing@example.com"): mock_existing_account, ( "TenantAccountJoin", "tenant_id", @@ -1261,7 +1319,11 @@ class TestRegisterService: ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) # Mock TenantService methods - with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission: + with ( + patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, + patch("services.account_service.TenantService.check_member_permission") as mock_check_permission, + ): + mock_lookup.return_value = mock_existing_account # Execute test and verify exception self._assert_exception_raised( AccountAlreadyInTenantError, @@ -1272,6 +1334,7 @@ class TestRegisterService: role="normal", inviter=mock_inviter, ) + mock_lookup.assert_called_once() def test_invite_new_member_no_inviter(self): """Test inviting a member without providing an inviter.""" @@ -1497,6 +1560,30 @@ class TestRegisterService: # Verify results assert result is None + def test_get_invitation_with_case_fallback_returns_initial_match(self): + """Fallback helper should return the initial invitation when present.""" + invitation = {"workspace_id": "tenant-456"} + with patch( + "services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation + ) as mock_get: + result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + + assert result == invitation + mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123") + + def test_get_invitation_with_case_fallback_retries_with_lowercase(self): + """Fallback helper should retry with lowercase email when needed.""" + invitation = {"workspace_id": "tenant-456"} + with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get: + mock_get.side_effect = [None, invitation] + result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + + assert result == invitation + assert mock_get.call_args_list == [ + (("tenant-456", "User@Test.com", "token-123"),), + (("tenant-456", "user@test.com", "token-123"),), + ] + # ==================== Helper Method Tests ==================== def test_get_invitation_token_key(self):