diff --git a/api/commands.py b/api/commands.py index a8d89ac200..acb66ea96d 100644 --- a/api/commands.py +++ b/api/commands.py @@ -34,7 +34,7 @@ from libs.rsa import generate_key_pair from models import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile +from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from models.provider import Provider, ProviderModel from models.provider_ids import DatasourceProviderID, ToolProviderID @@ -62,8 +62,10 @@ def reset_password(email, new_password, password_confirm): if str(new_password).strip() != str(password_confirm).strip(): click.echo(click.style("Passwords do not match.", fg="red")) return + normalized_email = email.strip().lower() + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = session.query(Account).where(Account.email == email).one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) @@ -84,7 +86,7 @@ def reset_password(email, new_password, password_confirm): base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(email) + AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -100,20 +102,22 @@ def reset_email(email, new_email, email_confirm): if str(new_email).strip() != str(email_confirm).strip(): click.echo(click.style("New emails do not match.", fg="red")) return + normalized_new_email = new_email.strip().lower() + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = session.query(Account).where(Account.email == email).one_or_none() + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) return try: - email_validate(new_email) + email_validate(normalized_new_email) except: click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account.email = new_email + account.email = normalized_new_email click.echo(click.style("Email updated successfully.", fg="green")) @@ -658,7 +662,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No return # Create account - email = email.strip() + email = email.strip().lower() if "@" not in email: click.echo(click.style("Invalid email address.", fg="red"))