mirror of https://github.com/langgenius/dify.git
chore: case insensitive email (#29978)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
0e33dfb5c2
commit
491e1fd6a4
|
|
@ -35,7 +35,7 @@ from libs.rsa import generate_key_pair
|
|||
from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
|
|
@ -64,8 +64,10 @@ def reset_password(email, new_password, password_confirm):
|
|||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
|
|
@ -86,7 +88,7 @@ def reset_password(email, new_password, password_confirm):
|
|||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
|
|
@ -102,20 +104,22 @@ def reset_email(email, new_email, email_confirm):
|
|||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
email_validate(normalized_new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
account.email = normalized_new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
|
|
@ -660,7 +664,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
|||
return
|
||||
|
||||
# Create account
|
||||
email = email.strip()
|
||||
email = email.strip().lower()
|
||||
|
||||
if "@" not in email:
|
||||
click.echo(click.style("Invalid email address.", fg="red"))
|
||||
|
|
|
|||
|
|
@ -63,10 +63,9 @@ class ActivateCheckApi(Resource):
|
|||
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workspaceId = args.workspace_id
|
||||
reg_email = args.email
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
tenant = invitation.get("tenant", None)
|
||||
|
|
@ -100,11 +99,12 @@ class ActivateApi(Resource):
|
|||
def post(self):
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
|
||||
normalized_request_email = args.email.lower() if args.email else None
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
|
||||
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
|
||||
|
||||
account = invitation["account"]
|
||||
account.name = args.name
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
|
|
@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
|
|||
@email_register_enabled
|
||||
def post(self):
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
|
|
@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
|
|||
if args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
token = None
|
||||
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
|
|
@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
|
|||
def post(self):
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
|
||||
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
|
||||
if is_email_register_error_rate_limit:
|
||||
raise EmailRegisterLimitError()
|
||||
|
||||
|
|
@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_email_register_error_rate_limit(args.email)
|
||||
AccountService.add_email_register_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
|
|||
user_email, code=args.code, additional_data={"phase": "register"}
|
||||
)
|
||||
|
||||
AccountService.reset_email_register_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_email_register_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register")
|
||||
|
|
@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
|
|||
AccountService.revoke_email_register_token(args.token)
|
||||
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(email, args.password_confirm)
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email, password) -> Account | None:
|
||||
def _create_new_account(self, email: str, password: str) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import secrets
|
|||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
|
|||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
|
@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
@email_password_login_enabled
|
||||
def post(self):
|
||||
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
|
|
@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
)
|
||||
|
|
@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
|
|||
def post(self):
|
||||
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
|
|
@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args.email)
|
||||
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
|
|||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=args.code, additional_data={"phase": "reset"}
|
||||
token_email, code=args.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password/resets")
|
||||
|
|
@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
|
|||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
|
|
|
|||
|
|
@ -90,32 +90,38 @@ class LoginApi(Resource):
|
|||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
args = LoginPayload.model_validate(console_ns.payload)
|
||||
request_email = args.email
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
invitation_data: dict[str, Any] | None = None
|
||||
if args.invite_token:
|
||||
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
|
||||
if invite_token:
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||
if invitation_data is None:
|
||||
invite_token = None
|
||||
|
||||
try:
|
||||
if invitation_data:
|
||||
data = invitation_data.get("data", {})
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args.email:
|
||||
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
|
||||
if invitee_email_normalized != normalized_email:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args.email, args.password, args.invite_token)
|
||||
else:
|
||||
account = AccountService.authenticate(args.email, args.password)
|
||||
account = _authenticate_account_with_case_fallback(
|
||||
request_email, normalized_email, args.password, invite_token
|
||||
)
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args.email)
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountPasswordError as exc:
|
||||
AccountService.add_login_error_rate_limit(normalized_email)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
|
|
@ -130,7 +136,7 @@ class LoginApi(Resource):
|
|||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
|
@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
|
|||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = _get_account_with_case_fallback(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
account=account,
|
||||
language=language,
|
||||
is_allow_register=FeatureService.get_system_features().is_allow_register,
|
||||
|
|
@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
|
|
@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args.email)
|
||||
account = _get_account_with_case_fallback(args.email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args.email, language=language)
|
||||
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
|
||||
else:
|
||||
raise AccountNotFound()
|
||||
else:
|
||||
|
|
@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
|
|||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
|
||||
user_email = args.email
|
||||
original_email = args.email
|
||||
user_email = original_email.lower()
|
||||
language = args.language
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args.email:
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args.code:
|
||||
|
|
@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
|
|||
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
account = _get_account_with_case_fallback(original_email)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
|
|
@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
|
|||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args.email)
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
|
@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
|
|||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
|
||||
def _get_account_with_case_fallback(email: str):
|
||||
account = AccountService.get_user_through_email(email)
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return AccountService.get_user_through_email(email.lower())
|
||||
|
||||
|
||||
def _authenticate_account_with_case_fallback(
|
||||
original_email: str, normalized_email: str, password: str, invite_token: str | None
|
||||
):
|
||||
try:
|
||||
return AccountService.authenticate(original_email, password, invite_token)
|
||||
except services.errors.account.AccountPasswordError:
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import logging
|
|||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
|
@ -118,7 +117,10 @@ class OAuthCallback(Resource):
|
|||
invitation = RegisterService.get_invitation_by_token(token=invite_token)
|
||||
if invitation:
|
||||
invitation_email = invitation.get("email", None)
|
||||
if invitation_email != user_info.email:
|
||||
invitation_email_normalized = (
|
||||
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
|
||||
)
|
||||
if invitation_email_normalized != user_info.email.lower():
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
|
@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
|||
|
||||
if not account:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
|
||||
|
||||
return account
|
||||
|
||||
|
|
@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
|||
tenant_was_created.send(new_tenant)
|
||||
|
||||
if not account:
|
||||
normalized_email = user_info.email.lower()
|
||||
oauth_new_user = True
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountRegisterError(
|
||||
description=(
|
||||
"This email account has been deleted within the past "
|
||||
|
|
@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
|||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
|
|
|
|||
|
|
@ -84,10 +84,11 @@ class SetupApi(Resource):
|
|||
raise NotInitValidateError()
|
||||
|
||||
args = SetupRequestPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args.email,
|
||||
email=normalized_email,
|
||||
name=args.name,
|
||||
password=args.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ from fields.member_fields import account_fields
|
|||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
|
@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
|
|||
else:
|
||||
language = "en-US"
|
||||
account = None
|
||||
user_email = args.email
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
|
|
@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
|
|||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email != current_user.email:
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
raise InvalidEmailError()
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
email_for_sending = account.email
|
||||
user_email = account.email
|
||||
|
||||
token = AccountService.send_change_email_email(
|
||||
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
|
||||
account=account,
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
|
@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
|
|||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
user_email = args.email
|
||||
user_email = args.email.lower()
|
||||
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
|
||||
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
|
||||
if is_change_email_error_rate_limit:
|
||||
raise EmailChangeLimitError()
|
||||
|
||||
|
|
@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
AccountService.add_change_email_error_rate_limit(args.email)
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
|
|||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(args.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/account/change-email/reset")
|
||||
|
|
@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
|
|||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
normalized_new_email = args.new_email.lower()
|
||||
|
||||
if AccountService.is_account_in_freeze(args.new_email):
|
||||
if AccountService.is_account_in_freeze(normalized_new_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if not AccountService.check_email_unique(args.new_email):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
|
|
@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
|
|||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email != old_email:
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
email=args.new_email,
|
||||
email=normalized_new_email,
|
||||
)
|
||||
|
||||
return updated_account
|
||||
|
|
@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
|
|||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = CheckEmailUniquePayload.model_validate(payload)
|
||||
if AccountService.is_account_in_freeze(args.email):
|
||||
normalized_email = args.email.lower()
|
||||
if AccountService.is_account_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(args.email):
|
||||
if not AccountService.check_email_unique(normalized_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
return {"result": "success"}
|
||||
|
|
|
|||
|
|
@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource):
|
|||
raise WorkspaceMembersLimitExceeded()
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
normalized_invitee_email = invitee_email.lower()
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
|
||||
tenant=inviter.current_tenant,
|
||||
email=invitee_email,
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(invitee_email)
|
||||
encoded_invitee_email = parse.quote(normalized_invitee_email)
|
||||
invitation_results.append(
|
||||
{
|
||||
"status": "success",
|
||||
"email": invitee_email,
|
||||
"email": normalized_invitee_email,
|
||||
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
|
||||
}
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
|
||||
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
|
||||
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import secrets
|
|||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
|
@ -22,7 +21,7 @@ from controllers.web import web_ns
|
|||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
|
|
@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
def post(self):
|
||||
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
request_email = payload.email
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
|
@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||
token = None
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
|
||||
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
|
@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
|
|||
def post(self):
|
||||
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
user_email = payload.email
|
||||
user_email = payload.email.lower()
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
|
|
@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
|
|||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if payload.code != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(payload.email)
|
||||
AccountService.add_forgot_password_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
|
|
@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
|
|||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_reset_password_token(
|
||||
user_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
token_email, code=payload.code, additional_data={"phase": "reset"}
|
||||
)
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(payload.email)
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
AccountService.reset_forgot_password_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
|
||||
|
||||
@web_ns.route("/forgot-password/resets")
|
||||
|
|
@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
|
|||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
|
|
|
|||
|
|
@ -197,25 +197,29 @@ class EmailCodeLoginApi(Resource):
|
|||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
user_email = args["email"].lower()
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
if normalized_token_email != user_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||
account = WebAppAuthService.get_user_through_email(user_email)
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
if not account:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from hashlib import sha256
|
|||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
|
@ -748,6 +748,21 @@ class AccountService:
|
|||
cls.email_code_login_rate_limiter.increment_rate_limit(email)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None:
|
||||
"""
|
||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||
|
||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
query_session = session or db.session
|
||||
account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
return TokenManager.get_token_data(token, "email_code_login")
|
||||
|
|
@ -1363,16 +1378,22 @@ class RegisterService:
|
|||
if not inviter:
|
||||
raise ValueError("Inviter is required")
|
||||
|
||||
normalized_email = email.lower()
|
||||
|
||||
"""Invite new member"""
|
||||
with Session(db.engine) as session:
|
||||
account = session.query(Account).filter_by(email=email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
name = email.split("@")[0]
|
||||
name = normalized_email.split("@")[0]
|
||||
|
||||
account = cls.register(
|
||||
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
|
||||
email=normalized_email,
|
||||
name=name,
|
||||
language=language,
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
# Create new tenant member for invited tenant
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
|
|
@ -1394,7 +1415,7 @@ class RegisterService:
|
|||
# send email
|
||||
send_invite_member_mail_task.delay(
|
||||
language=language,
|
||||
to=email,
|
||||
to=account.email,
|
||||
token=token,
|
||||
inviter_name=inviter.name if inviter else "Dify",
|
||||
workspace_name=tenant.name,
|
||||
|
|
@ -1493,6 +1514,16 @@ class RegisterService:
|
|||
invitation: dict = json.loads(data)
|
||||
return invitation
|
||||
|
||||
@classmethod
|
||||
def get_invitation_with_case_fallback(
|
||||
cls, workspace_id: str | None, email: str | None, token: str
|
||||
) -> dict[str, Any] | None:
|
||||
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
|
||||
if invitation or not email or email == email.lower():
|
||||
return invitation
|
||||
normalized_email = email.lower()
|
||||
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
|
||||
|
||||
|
||||
def _generate_refresh_token(length: int = 64):
|
||||
token = secrets.token_hex(length)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from libs.passport import PassportService
|
|||
from libs.password import compare_password
|
||||
from models import Account, AccountStatus
|
||||
from models.model import App, EndUser, Site
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
|
|
@ -32,7 +33,7 @@ class WebAppAuthService:
|
|||
@staticmethod
|
||||
def authenticate(email: str, password: str) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
|
||||
|
|
@ -52,7 +53,7 @@ class WebAppAuthService:
|
|||
|
||||
@classmethod
|
||||
def get_user_through_email(cls, email: str):
|
||||
account = db.session.query(Account).where(Account.email == email).first()
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class TestActivateCheckApi:
|
|||
"tenant": tenant,
|
||||
}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking valid invitation token.
|
||||
|
|
@ -66,7 +66,7 @@ class TestActivateCheckApi:
|
|||
assert response["data"]["workspace_id"] == "workspace-123"
|
||||
assert response["data"]["email"] == "invitee@example.com"
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_invalid_invitation_token(self, mock_get_invitation, app):
|
||||
"""
|
||||
Test checking invalid invitation token.
|
||||
|
|
@ -88,7 +88,7 @@ class TestActivateCheckApi:
|
|||
# Assert
|
||||
assert response["is_valid"] is False
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking token without workspace ID.
|
||||
|
|
@ -109,7 +109,7 @@ class TestActivateCheckApi:
|
|||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
|
||||
"""
|
||||
Test checking token without email parameter.
|
||||
|
|
@ -130,6 +130,20 @@ class TestActivateCheckApi:
|
|||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
|
||||
"""Ensure token validation uses lowercase emails."""
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
with app.test_request_context(
|
||||
"/activate/check?workspace_id=workspace-123&email=Invitee@Example.com&token=valid_token"
|
||||
):
|
||||
api = ActivateCheckApi()
|
||||
response = api.get()
|
||||
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
|
||||
|
||||
|
||||
class TestActivateApi:
|
||||
"""Test cases for account activation endpoint."""
|
||||
|
|
@ -212,7 +226,7 @@ class TestActivateApi:
|
|||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_activation_with_invalid_token(self, mock_get_invitation, app):
|
||||
"""
|
||||
Test account activation with invalid token.
|
||||
|
|
@ -241,7 +255,7 @@ class TestActivateApi:
|
|||
with pytest.raises(AlreadyActivateError):
|
||||
api.post()
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_sets_interface_theme(
|
||||
|
|
@ -290,7 +304,7 @@ class TestActivateApi:
|
|||
("es-ES", "Europe/Madrid"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_with_different_locales(
|
||||
|
|
@ -336,7 +350,7 @@ class TestActivateApi:
|
|||
assert mock_account.interface_language == language
|
||||
assert mock_account.timezone == timezone
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_returns_success_response(
|
||||
|
|
@ -376,7 +390,7 @@ class TestActivateApi:
|
|||
# Assert
|
||||
assert response == {"result": "success"}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_without_workspace_id(
|
||||
|
|
@ -415,3 +429,37 @@ class TestActivateApi:
|
|||
# Assert
|
||||
assert response["result"] == "success"
|
||||
mock_revoke_token.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.activate.RegisterService.revoke_token")
|
||||
@patch("controllers.console.auth.activate.db")
|
||||
def test_activation_normalizes_email_before_lookup(
|
||||
self,
|
||||
mock_db,
|
||||
mock_revoke_token,
|
||||
mock_get_invitation,
|
||||
app,
|
||||
mock_invitation,
|
||||
mock_account,
|
||||
):
|
||||
"""Ensure uppercase emails are normalized before lookup and revocation."""
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
with app.test_request_context(
|
||||
"/activate",
|
||||
method="POST",
|
||||
json={
|
||||
"workspace_id": "workspace-123",
|
||||
"email": "Invitee@Example.com",
|
||||
"token": "valid_token",
|
||||
"name": "John Doe",
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
):
|
||||
api = ActivateApi()
|
||||
response = api.post()
|
||||
|
||||
assert response["result"] == "success"
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class TestAuthenticationSecurity:
|
|||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invalid_email_with_registration_allowed(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
|
|
@ -67,7 +67,7 @@ class TestAuthenticationSecurity:
|
|||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_wrong_password_returns_error(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_db
|
||||
):
|
||||
|
|
@ -100,7 +100,7 @@ class TestAuthenticationSecurity:
|
|||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invalid_email_with_registration_disabled(
|
||||
self, mock_get_invitation, mock_add_rate_limit, mock_authenticate, mock_is_rate_limit, mock_features, mock_db
|
||||
):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,177 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.email_register import (
|
||||
EmailRegisterCheckApi,
|
||||
EmailRegisterResetApi,
|
||||
EmailRegisterSendEmailApi,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class TestEmailRegisterSendEmailApi:
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.send_email_register_email")
|
||||
@patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze")
|
||||
@patch("controllers.console.auth.email_register.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_send_email_normalizes_and_falls_back(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_email_send_ip_limit,
|
||||
mock_is_freeze,
|
||||
mock_send_mail,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_is_freeze.return_value = False
|
||||
mock_account = MagicMock()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register/send-email",
|
||||
method="POST",
|
||||
json={"email": "Invitee@Example.com", "language": "en-US"},
|
||||
):
|
||||
response = EmailRegisterSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_is_freeze.assert_called_once_with("invitee@example.com")
|
||||
mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US")
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
||||
class TestEmailRegisterCheckApi:
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_email_register_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.generate_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.add_email_register_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.AccountService.is_email_register_error_rate_limit")
|
||||
def test_validity_normalizes_email_before_checks(
|
||||
self,
|
||||
mock_rate_limit_check,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "4321", "token": "token-123"},
|
||||
):
|
||||
response = EmailRegisterCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_rate_limit_check.assert_called_once_with("user@example.com")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="4321", additional_data={"phase": "register"}
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
class TestEmailRegisterResetApi:
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.Session")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_reset_creates_account_with_normalized_email(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register",
|
||||
method="POST",
|
||||
json={"token": "token-123", "new_password": "ValidPass123!", "password_confirm": "ValidPass123!"},
|
||||
):
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session)
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
ForgotPasswordResetApi,
|
||||
ForgotPasswordSendEmailApi,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.auth.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_send_normalizes_email(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
controller_features = SimpleNamespace(is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch(
|
||||
"controllers.console.auth.forgot_password.FeatureService.get_system_features",
|
||||
return_value=controller_features,
|
||||
),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=mock_account,
|
||||
email="user@example.com",
|
||||
language="zh-Hans",
|
||||
is_allow_register=True,
|
||||
)
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_check_normalizes_email(
|
||||
self,
|
||||
mock_rate_limit_check,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_rate_limit_check.return_value = False
|
||||
mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "ADMIN@Example.com", "code": "4321", "token": "token-123"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "admin@example.com", "token": "new-token"}
|
||||
mock_rate_limit_check.assert_called_once_with("admin@example.com")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"Admin@Example.com",
|
||||
code="4321",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("admin@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_fetches_account_with_original_email(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_session_cls,
|
||||
mock_update_account,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")),
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("token-123")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_update_account.assert_called_once()
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
mock_session = MagicMock()
|
||||
first_query = MagicMock()
|
||||
first_query.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_query = MagicMock()
|
||||
second_query.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_query, second_query]
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session)
|
||||
|
||||
assert account is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
|
@ -76,7 +76,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
|
|
@ -120,7 +120,7 @@ class TestLoginApi:
|
|||
response = login_api.post()
|
||||
|
||||
# Assert
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!")
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
assert response.json["result"] == "success"
|
||||
|
|
@ -128,7 +128,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
|
|
@ -182,7 +182,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test login rejection when rate limit is exceeded.
|
||||
|
|
@ -230,7 +230,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
def test_login_fails_with_invalid_credentials(
|
||||
|
|
@ -269,7 +269,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
def test_login_fails_for_banned_account(
|
||||
self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app
|
||||
|
|
@ -298,7 +298,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.FeatureService.get_system_features")
|
||||
|
|
@ -343,7 +343,7 @@ class TestLoginApi:
|
|||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_if_token_valid")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app):
|
||||
"""
|
||||
Test login failure when invitation email doesn't match login email.
|
||||
|
|
@ -371,6 +371,52 @@ class TestLoginApi:
|
|||
with pytest.raises(InvalidEmailError):
|
||||
login_api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False)
|
||||
@patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback")
|
||||
@patch("controllers.console.auth.login.AccountService.authenticate")
|
||||
@patch("controllers.console.auth.login.AccountService.add_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.login.TenantService.get_join_tenants")
|
||||
@patch("controllers.console.auth.login.AccountService.login")
|
||||
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
|
||||
def test_login_retries_with_lowercase_email(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_login_service,
|
||||
mock_get_tenants,
|
||||
mock_add_rate_limit,
|
||||
mock_authenticate,
|
||||
mock_get_invitation,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
mock_account,
|
||||
mock_token_pair,
|
||||
):
|
||||
"""Test that login retries with lowercase email when uppercase lookup fails."""
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_invitation.return_value = None
|
||||
mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account]
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
mock_login_service.return_value = mock_token_pair
|
||||
|
||||
with app.test_request_context(
|
||||
"/login",
|
||||
method="POST",
|
||||
json={"email": "Upper@Example.com", "password": encode_password("ValidPass123!")},
|
||||
):
|
||||
response = LoginApi().post()
|
||||
|
||||
assert response.json["result"] == "success"
|
||||
assert mock_authenticate.call_args_list == [
|
||||
(("Upper@Example.com", "ValidPass123!", None), {}),
|
||||
(("upper@example.com", "ValidPass123!", None), {}),
|
||||
]
|
||||
mock_add_rate_limit.assert_not_called()
|
||||
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
"""Test cases for the LogoutApi endpoint."""
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from controllers.console.auth.oauth import (
|
|||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
|
|
@ -215,6 +216,34 @@ class TestOAuthCallback:
|
|||
assert status_code == 400
|
||||
assert response["error"] == expected_error
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_invitation_comparison_is_case_insensitive(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_register_service,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
oauth_setup["provider"].get_user_info.return_value = OAuthUserInfo(
|
||||
id="123", name="Test User", email="User@Example.com"
|
||||
)
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
mock_register_service.is_valid_invite_token.return_value = True
|
||||
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
|
||||
resource.get("github")
|
||||
|
||||
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
|
||||
mock_redirect.assert_called_once_with("http://localhost:3000/signin/invite-settings?invite_token=invite123")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("account_status", "expected_redirect"),
|
||||
[
|
||||
|
|
@ -395,12 +424,12 @@ class TestAccountGeneration:
|
|||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.oauth.Session")
|
||||
@patch("controllers.console.auth.oauth.select")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_get_account_by_openid_or_email(
|
||||
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
|
||||
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
|
||||
):
|
||||
# Mock db.engine for Session creation
|
||||
mock_db.engine = MagicMock()
|
||||
|
|
@ -410,15 +439,31 @@ class TestAccountGeneration:
|
|||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
||||
mock_get_account.assert_not_called()
|
||||
|
||||
# Test fallback to email
|
||||
# Test fallback to email lookup
|
||||
mock_account_model.get_by_openid.return_value = None
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
|
||||
mock_session = MagicMock()
|
||||
first_result = MagicMock()
|
||||
first_result.scalar_one_or_none.return_value = None
|
||||
expected_account = MagicMock()
|
||||
second_result = MagicMock()
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||
|
||||
assert result == expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_register", "existing_account", "should_create"),
|
||||
|
|
@ -466,6 +511,35 @@ class TestAccountGeneration:
|
|||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
else:
|
||||
mock_register_service.register.assert_not_called()
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_register_with_lowercase_email(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
):
|
||||
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_register_service.register.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US"}):
|
||||
_generate_account("github", user_info)
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
|
|
|
|||
|
|
@ -28,6 +28,22 @@ from controllers.console.auth.forgot_password import (
|
|||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_session():
|
||||
with patch("controllers.console.auth.forgot_password.Session") as mock_session_cls:
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
mock_session_cls.return_value.__exit__.return_value = None
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_forgot_password_db():
|
||||
with patch("controllers.console.auth.forgot_password.db") as mock_db:
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
"""Test cases for sending password reset emails."""
|
||||
|
||||
|
|
@ -47,20 +63,16 @@ class TestForgotPasswordSendEmailApi:
|
|||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
|
||||
def test_send_reset_email_success(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_send_email,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
|
|
@ -75,11 +87,8 @@ class TestForgotPasswordSendEmailApi:
|
|||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "reset_token_123"
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
|
||||
|
|
@ -125,20 +134,16 @@ class TestForgotPasswordSendEmailApi:
|
|||
],
|
||||
)
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.console.auth.forgot_password.FeatureService.get_system_features")
|
||||
def test_send_reset_email_language_handling(
|
||||
self,
|
||||
mock_get_features,
|
||||
mock_send_email,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_is_ip_limit,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
|
|
@ -154,11 +159,8 @@ class TestForgotPasswordSendEmailApi:
|
|||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_is_ip_limit.return_value = False
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_email.return_value = "token"
|
||||
mock_get_features.return_value.is_allow_register = True
|
||||
|
||||
|
|
@ -229,8 +231,46 @@ class TestForgotPasswordCheckApi:
|
|||
assert response["email"] == "test@example.com"
|
||||
assert response["token"] == "new_token"
|
||||
mock_revoke_token.assert_called_once_with("old_token")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"test@example.com", code="123456", additional_data={"phase": "reset"}
|
||||
)
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
def test_verify_code_preserves_token_email_case(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_generate_token,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_is_rate_limit,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "999888", "token": "upper_token"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "fresh-token"}
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"User@Example.com", code="999888", additional_data={"phase": "reset"}
|
||||
)
|
||||
mock_revoke_token.assert_called_once_with("upper_token")
|
||||
mock_reset_rate_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_verify_code_rate_limited(self, mock_is_rate_limit, mock_db, app):
|
||||
|
|
@ -355,20 +395,16 @@ class TestForgotPasswordResetApi:
|
|||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants")
|
||||
def test_reset_password_success(
|
||||
self,
|
||||
mock_get_tenants,
|
||||
mock_select,
|
||||
mock_session,
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_data,
|
||||
mock_forgot_db,
|
||||
mock_wraps_db,
|
||||
app,
|
||||
mock_account,
|
||||
|
|
@ -383,11 +419,8 @@ class TestForgotPasswordResetApi:
|
|||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"}
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_get_tenants.return_value = [MagicMock()]
|
||||
|
||||
# Act
|
||||
|
|
@ -475,13 +508,11 @@ class TestForgotPasswordResetApi:
|
|||
api.post()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.forgot_password.db")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.console.auth.forgot_password.Session")
|
||||
@patch("controllers.console.auth.forgot_password.select")
|
||||
@patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_reset_password_account_not_found(
|
||||
self, mock_select, mock_session, mock_revoke_token, mock_get_data, mock_forgot_db, mock_wraps_db, app
|
||||
self, mock_get_account, mock_revoke_token, mock_get_data, mock_wraps_db, app
|
||||
):
|
||||
"""
|
||||
Test password reset for non-existent account.
|
||||
|
|
@ -491,11 +522,8 @@ class TestForgotPasswordResetApi:
|
|||
"""
|
||||
# Arrange
|
||||
mock_wraps_db.session.query.return_value.first.return_value = MagicMock()
|
||||
mock_forgot_db.engine = MagicMock()
|
||||
mock_get_data.return_value = {"email": "nonexistent@example.com", "phase": "reset"}
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = None
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_get_account.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.console.setup import SetupApi
|
||||
|
||||
|
||||
class TestSetupApi:
|
||||
def test_post_lowercases_email_before_register(self):
|
||||
"""Ensure setup registration normalizes email casing."""
|
||||
payload = {
|
||||
"email": "Admin@Example.com",
|
||||
"name": "Admin User",
|
||||
"password": "ValidPass123!",
|
||||
"language": "en-US",
|
||||
}
|
||||
setup_api = SetupApi(api=None)
|
||||
|
||||
mock_console_ns = SimpleNamespace(payload=payload)
|
||||
|
||||
with (
|
||||
patch("controllers.console.setup.console_ns", mock_console_ns),
|
||||
patch("controllers.console.setup.get_setup_status", return_value=False),
|
||||
patch("controllers.console.setup.TenantService.get_tenant_count", return_value=0),
|
||||
patch("controllers.console.setup.get_init_validate_status", return_value=True),
|
||||
patch("controllers.console.setup.extract_remote_ip", return_value="127.0.0.1"),
|
||||
patch("controllers.console.setup.request", object()),
|
||||
patch("controllers.console.setup.RegisterService.setup") as mock_register,
|
||||
):
|
||||
response, status = setup_api.post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
assert status == 201
|
||||
mock_register.assert_called_once_with(
|
||||
email="admin@example.com",
|
||||
name=payload["name"],
|
||||
password=payload["password"],
|
||||
ip_address="127.0.0.1",
|
||||
language=payload["language"],
|
||||
)
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.workspace.account import (
|
||||
AccountDeleteUpdateFeedbackApi,
|
||||
ChangeEmailCheckApi,
|
||||
ChangeEmailResetApi,
|
||||
ChangeEmailSendEmailApi,
|
||||
CheckEmailUnique,
|
||||
)
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
||||
return app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
|
||||
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
|
||||
account = Account(name=account_id, email=email)
|
||||
account.email = email
|
||||
account.id = account_id
|
||||
account.status = "active"
|
||||
account._current_tenant = tenant_obj
|
||||
return account
|
||||
|
||||
|
||||
def _set_logged_in_user(account: Account):
|
||||
g._login_user = account
|
||||
g._current_tenant = account.current_tenant
|
||||
|
||||
|
||||
class TestChangeEmailSend:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {"email": "current@example.com"}
|
||||
mock_send_email.return_value = "token-abc"
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-abc"}
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=None,
|
||||
email="new@example.com",
|
||||
old_email="current@example.com",
|
||||
language="en-US",
|
||||
phase="new_email",
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_validate_with_normalized_email(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_before_update(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "New@Example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_is_freeze.assert_called_once_with("new@example.com")
|
||||
mock_check_unique.assert_called_once_with("new@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
|
||||
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
with app.test_request_context(
|
||||
"/account/delete/feedback",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "feedback": "test"},
|
||||
):
|
||||
response = AccountDeleteUpdateFeedbackApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_update.assert_called_once_with("User@Example.com", "test")
|
||||
|
||||
|
||||
class TestCheckEmailUnique:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/check-email-unique",
|
||||
method="POST",
|
||||
json={"email": "Case@Test.com"},
|
||||
):
|
||||
response = CheckEmailUnique().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_is_freeze.assert_called_once_with("case@test.com")
|
||||
mock_check_unique.assert_called_once_with("case@test.com")
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
session = MagicMock()
|
||||
first = MagicMock()
|
||||
first.scalar_one_or_none.return_value = None
|
||||
second = MagicMock()
|
||||
expected_account = MagicMock()
|
||||
second.scalar_one_or_none.return_value = expected_account
|
||||
session.execute.side_effect = [first, second]
|
||||
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session)
|
||||
|
||||
assert result is expected_account
|
||||
assert session.execute.call_count == 2
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.workspace.members import MemberInviteEmailApi
|
||||
from models.account import Account, TenantAccountRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
flask_app.login_manager = SimpleNamespace(_load_user=lambda: None)
|
||||
return flask_app
|
||||
|
||||
|
||||
def _mock_wraps_db(mock_db):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
|
||||
|
||||
def _build_feature_flags():
|
||||
placeholder_quota = SimpleNamespace(limit=0, size=0)
|
||||
workspace_members = SimpleNamespace(is_available=lambda count: True)
|
||||
return SimpleNamespace(
|
||||
billing=SimpleNamespace(enabled=False),
|
||||
workspace_members=workspace_members,
|
||||
members=placeholder_quota,
|
||||
apps=placeholder_quota,
|
||||
vector_space=placeholder_quota,
|
||||
documents_upload_quota=placeholder_quota,
|
||||
annotation_quota_limit=placeholder_quota,
|
||||
)
|
||||
|
||||
|
||||
class TestMemberInviteEmailApi:
|
||||
@patch("controllers.console.workspace.members.FeatureService.get_features")
|
||||
@patch("controllers.console.workspace.members.RegisterService.invite_new_member")
|
||||
@patch("controllers.console.workspace.members.current_account_with_tenant")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
def test_invite_normalizes_emails(
|
||||
self,
|
||||
mock_csrf,
|
||||
mock_db,
|
||||
mock_current_account,
|
||||
mock_invite_member,
|
||||
mock_get_features,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_get_features.return_value = _build_feature_flags()
|
||||
mock_invite_member.return_value = "token-abc"
|
||||
|
||||
tenant = SimpleNamespace(id="tenant-1", name="Test Tenant")
|
||||
inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active")
|
||||
mock_current_account.return_value = (inviter, tenant.id)
|
||||
|
||||
with patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"):
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/members/invite-email",
|
||||
method="POST",
|
||||
json={"emails": ["User@Example.com"], "role": TenantAccountRole.EDITOR.value, "language": "en-US"},
|
||||
):
|
||||
account = Account(name="tester", email="tester@example.com")
|
||||
account._current_tenant = tenant
|
||||
g._login_user = account
|
||||
g._current_tenant = tenant
|
||||
response, status_code = MemberInviteEmailApi().post()
|
||||
|
||||
assert status_code == 201
|
||||
assert response["invitation_results"][0]["email"] == "user@example.com"
|
||||
|
||||
assert mock_invite_member.call_count == 1
|
||||
call_args = mock_invite_member.call_args
|
||||
assert call_args.kwargs["tenant"] == tenant
|
||||
assert call_args.kwargs["email"] == "User@Example.com"
|
||||
assert call_args.kwargs["language"] == "en-US"
|
||||
assert call_args.kwargs["role"] == TenantAccountRole.EDITOR
|
||||
assert call_args.kwargs["inviter"] == inviter
|
||||
mock_csrf.assert_called_once()
|
||||
|
|
@ -1,195 +0,0 @@
|
|||
"""Unit tests for controllers.web.forgot_password endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_controller_module():
|
||||
"""Import controllers.web.forgot_password using a stub package."""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
parent_module_name = "controllers.web"
|
||||
module_name = f"{parent_module_name}.forgot_password"
|
||||
|
||||
if parent_module_name not in sys.modules:
|
||||
from flask_restx import Namespace
|
||||
|
||||
stub = ModuleType(parent_module_name)
|
||||
stub.__file__ = "controllers/web/__init__.py"
|
||||
stub.__path__ = ["controllers/web"]
|
||||
stub.__package__ = "controllers"
|
||||
stub.__spec__ = importlib.util.spec_from_loader(parent_module_name, loader=None, is_package=True)
|
||||
stub.web_ns = Namespace("web", description="Web API", path="/")
|
||||
sys.modules[parent_module_name] = stub
|
||||
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
forgot_password_module = _load_controller_module()
|
||||
ForgotPasswordCheckApi = forgot_password_module.ForgotPasswordCheckApi
|
||||
ForgotPasswordResetApi = forgot_password_module.ForgotPasswordResetApi
|
||||
ForgotPasswordSendEmailApi = forgot_password_module.ForgotPasswordSendEmailApi
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Configure a minimal Flask app for request contexts."""
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_web_endpoint_guards():
|
||||
"""Stub enterprise and feature toggles used by route decorators."""
|
||||
|
||||
features = SimpleNamespace(enable_email_password_login=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=features),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_controller_db():
|
||||
"""Replace controller-level db reference with a simple stub."""
|
||||
|
||||
fake_db = SimpleNamespace(engine=MagicMock(name="engine"))
|
||||
fake_wraps_db = SimpleNamespace(
|
||||
session=MagicMock(query=MagicMock(return_value=MagicMock(first=MagicMock(return_value=True))))
|
||||
)
|
||||
with (
|
||||
patch("controllers.web.forgot_password.db", fake_db),
|
||||
patch("controllers.console.wraps.db", fake_wraps_db),
|
||||
):
|
||||
yield fake_db
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email", return_value="reset-token")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="203.0.113.10")
|
||||
def test_send_reset_email_success(
|
||||
mock_extract_ip: MagicMock,
|
||||
mock_is_ip_limit: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_send_email: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password returns token when email exists and limits allow."""
|
||||
|
||||
mock_account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "user@example.com"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "reset-token"}
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("203.0.113.10")
|
||||
mock_send_email.assert_called_once_with(account=mock_account, email="user@example.com", language="en-US")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token", return_value=({}, "new-token"))
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit", return_value=False)
|
||||
def test_check_token_success(
|
||||
mock_is_rate_limited: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke: MagicMock,
|
||||
mock_generate: MagicMock,
|
||||
mock_reset_limit: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/validity validates the code and refreshes token."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "123456"}
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "code": "123456", "token": "old-token"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limited.assert_called_once_with("user@example.com")
|
||||
mock_get_data.assert_called_once_with("old-token")
|
||||
mock_revoke.assert_called_once_with("old-token")
|
||||
mock_generate.assert_called_once_with(
|
||||
"user@example.com",
|
||||
code="123456",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_limit.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_reset_password_success(
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
mock_token_bytes: MagicMock,
|
||||
mock_hash_password: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""POST /forgot-password/resets updates the stored password when token is valid."""
|
||||
|
||||
mock_get_data.return_value = {"email": "user@example.com", "phase": "reset"}
|
||||
account = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = session_ctx
|
||||
session_ctx.execute.return_value.scalar_one_or_none.return_value = account
|
||||
|
||||
with app.test_request_context(
|
||||
"/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_data.assert_called_once_with("reset-token")
|
||||
mock_revoke_token.assert_called_once_with("reset-token")
|
||||
mock_token_bytes.assert_called_once_with(16)
|
||||
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||
assert account.password == expected_password
|
||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||
assert account.password_salt == expected_salt
|
||||
session_ctx.commit.assert_called_once()
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
ForgotPasswordResetApi,
|
||||
ForgotPasswordSendEmailApi,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_wraps():
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
dify_settings = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
|
||||
with (
|
||||
patch("controllers.console.wraps.db") as mock_db,
|
||||
patch("controllers.console.wraps.dify_config", dify_settings),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
def test_should_normalize_email_before_sending(
|
||||
self,
|
||||
mock_session_cls,
|
||||
mock_extract_ip,
|
||||
mock_rate_limit,
|
||||
mock_get_account,
|
||||
mock_send_mail,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_send_mail.return_value = "token-123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "zh-Hans"},
|
||||
):
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.add_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_should_normalize_email_for_validity_checks(
|
||||
self,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("user@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"User@Example.com",
|
||||
code="1234",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_forgot_password_error_rate_limit")
|
||||
def test_should_preserve_token_email_case(
|
||||
self,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
app,
|
||||
):
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
|
||||
mock_generate_token.return_value = (None, "fresh-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/validity",
|
||||
method="POST",
|
||||
json={"email": "mixedcase@example.com", "code": "5678", "token": "token-upper"},
|
||||
):
|
||||
response = ForgotPasswordCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "mixedcase@example.com", "token": "fresh-token"}
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"MixedCase@Example.com",
|
||||
code="5678",
|
||||
additional_data={"phase": "reset"},
|
||||
)
|
||||
mock_revoke_token.assert_called_once_with("token-upper")
|
||||
mock_reset_rate.assert_called_once_with("mixedcase@example.com")
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
def test_should_fetch_account_with_fallback(
|
||||
self,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_get_account,
|
||||
mock_update_account,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
|
||||
mock_account = MagicMock()
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
|
||||
mock_update_account.assert_called_once()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
@patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
|
||||
@patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
def test_should_update_password_and_commit(
|
||||
self,
|
||||
mock_get_account,
|
||||
mock_get_reset_data,
|
||||
mock_revoke_token,
|
||||
mock_session_cls,
|
||||
mock_token_bytes,
|
||||
mock_hash_password,
|
||||
app,
|
||||
):
|
||||
mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
|
||||
account = MagicMock()
|
||||
mock_get_account.return_value = account
|
||||
mock_session = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
|
||||
with app.test_request_context(
|
||||
"/web/forgot-password/resets",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "reset-token",
|
||||
"new_password": "StrongPass123!",
|
||||
"password_confirm": "StrongPass123!",
|
||||
},
|
||||
):
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_reset_data.assert_called_once_with("reset-token")
|
||||
mock_revoke_token.assert_called_once_with("reset-token")
|
||||
mock_token_bytes.assert_called_once_with(16)
|
||||
mock_hash_password.assert_called_once_with("StrongPass123!", b"0123456789abcdef")
|
||||
expected_password = base64.b64encode(b"hashed-value").decode()
|
||||
assert account.password == expected_password
|
||||
expected_salt = base64.b64encode(b"0123456789abcdef").decode()
|
||||
assert account.password_salt == expected_salt
|
||||
mock_session.commit.assert_called_once()
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
return base64.b64encode(code.encode("utf-8")).decode()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_wraps():
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
console_dify = SimpleNamespace(ENTERPRISE_ENABLED=True, EDITION="CLOUD")
|
||||
web_dify = SimpleNamespace(ENTERPRISE_ENABLED=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.db") as mock_db,
|
||||
patch("controllers.console.wraps.dify_config", console_dify),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
|
||||
patch("controllers.web.login.dify_config", web_dify),
|
||||
):
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock()
|
||||
yield
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendEmailApi:
|
||||
@patch("controllers.web.login.WebAppAuthService.send_email_code_login_email")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
|
||||
def test_should_fetch_account_with_original_email(
|
||||
self,
|
||||
mock_get_user,
|
||||
mock_send_email,
|
||||
app,
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_get_user.return_value = mock_account
|
||||
mock_send_email.return_value = "token-123"
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/email-code-login",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "language": "en-US"},
|
||||
):
|
||||
response = EmailCodeLoginSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_user.assert_called_once_with("User@Example.com")
|
||||
mock_send_email.assert_called_once_with(account=mock_account, language="en-US")
|
||||
|
||||
|
||||
class TestEmailCodeLoginApi:
|
||||
@patch("controllers.web.login.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.web.login.WebAppAuthService.login", return_value="new-access-token")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_user_through_email")
|
||||
@patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token")
|
||||
@patch("controllers.web.login.WebAppAuthService.get_email_code_login_data")
|
||||
def test_should_normalize_email_before_validating(
|
||||
self,
|
||||
mock_get_token_data,
|
||||
mock_revoke_token,
|
||||
mock_get_user,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app,
|
||||
):
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
mock_get_user.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/email-code-login/validity",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"},
|
||||
):
|
||||
response = EmailCodeLoginApi().post()
|
||||
|
||||
assert response.get_json() == {"result": "success", "data": {"access_token": "new-access-token"}}
|
||||
mock_get_user.assert_called_once_with("User@Example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_login_rate.assert_called_once_with("user@example.com")
|
||||
|
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from models.account import Account
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
|
|
@ -1147,9 +1147,13 @@ class TestRegisterService:
|
|||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
||||
|
||||
with patch("services.account_service.Session") as mock_session_class:
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = None
|
||||
|
||||
# Mock RegisterService.register
|
||||
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
|
|
@ -1182,9 +1186,59 @@ class TestRegisterService:
|
|||
email="newuser@example.com",
|
||||
name="newuser",
|
||||
language="en-US",
|
||||
status="pending",
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session)
|
||||
|
||||
def test_invite_new_member_normalizes_new_account_email(
|
||||
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
||||
):
|
||||
"""Ensure inviting with mixed-case email normalizes before registering."""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-456"
|
||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||
mixed_email = "Invitee@Example.com"
|
||||
|
||||
mock_session = MagicMock()
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = None
|
||||
|
||||
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="new-user-789", email="invitee@example.com", name="invitee", status="pending"
|
||||
)
|
||||
with patch("services.account_service.RegisterService.register") as mock_register:
|
||||
mock_register.return_value = mock_new_account
|
||||
with (
|
||||
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
|
||||
patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
|
||||
patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant,
|
||||
patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
|
||||
):
|
||||
mock_generate_token.return_value = "invite-token-abc"
|
||||
|
||||
RegisterService.invite_new_member(
|
||||
tenant=mock_tenant,
|
||||
email=mixed_email,
|
||||
language="en-US",
|
||||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
)
|
||||
|
||||
mock_register.assert_called_once_with(
|
||||
email="invitee@example.com",
|
||||
name="invitee",
|
||||
language="en-US",
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with(mixed_email, session=mock_session)
|
||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
|
||||
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
|
||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account)
|
||||
|
|
@ -1207,9 +1261,13 @@ class TestRegisterService:
|
|||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
||||
|
||||
with patch("services.account_service.Session") as mock_session_class:
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
|
||||
# Mock the db.session.query for TenantAccountJoin
|
||||
mock_db_query = MagicMock()
|
||||
|
|
@ -1238,6 +1296,7 @@ class TestRegisterService:
|
|||
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
|
||||
mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
|
||||
mock_task_dependencies.delay.assert_called_once()
|
||||
mock_lookup.assert_called_once_with("existing@example.com", session=mock_session)
|
||||
|
||||
def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
|
||||
"""Test inviting a member who is already in the tenant."""
|
||||
|
|
@ -1251,7 +1310,6 @@ class TestRegisterService:
|
|||
|
||||
# Mock database queries
|
||||
query_results = {
|
||||
("Account", "email", "existing@example.com"): mock_existing_account,
|
||||
(
|
||||
"TenantAccountJoin",
|
||||
"tenant_id",
|
||||
|
|
@ -1261,7 +1319,11 @@ class TestRegisterService:
|
|||
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
|
||||
|
||||
# Mock TenantService methods
|
||||
with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission:
|
||||
with (
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
|
||||
):
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
AccountAlreadyInTenantError,
|
||||
|
|
@ -1272,6 +1334,7 @@ class TestRegisterService:
|
|||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
)
|
||||
mock_lookup.assert_called_once()
|
||||
|
||||
def test_invite_new_member_no_inviter(self):
|
||||
"""Test inviting a member without providing an inviter."""
|
||||
|
|
@ -1497,6 +1560,30 @@ class TestRegisterService:
|
|||
# Verify results
|
||||
assert result is None
|
||||
|
||||
def test_get_invitation_with_case_fallback_returns_initial_match(self):
|
||||
"""Fallback helper should return the initial invitation when present."""
|
||||
invitation = {"workspace_id": "tenant-456"}
|
||||
with patch(
|
||||
"services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation
|
||||
) as mock_get:
|
||||
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
assert result == invitation
|
||||
mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
def test_get_invitation_with_case_fallback_retries_with_lowercase(self):
|
||||
"""Fallback helper should retry with lowercase email when needed."""
|
||||
invitation = {"workspace_id": "tenant-456"}
|
||||
with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get:
|
||||
mock_get.side_effect = [None, invitation]
|
||||
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
|
||||
|
||||
assert result == invitation
|
||||
assert mock_get.call_args_list == [
|
||||
(("tenant-456", "User@Test.com", "token-123"),),
|
||||
(("tenant-456", "user@test.com", "token-123"),),
|
||||
]
|
||||
|
||||
# ==================== Helper Method Tests ====================
|
||||
|
||||
def test_get_invitation_token_key(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue