diff --git a/api/.env.example b/api/.env.example index edbb684cc7..22287e6b34 100644 --- a/api/.env.example +++ b/api/.env.example @@ -277,3 +277,8 @@ POSITION_TOOL_EXCLUDES= POSITION_PROVIDER_PINS= POSITION_PROVIDER_INCLUDES= POSITION_PROVIDER_EXCLUDES= + +# Login +EMAIL_CODE_LOGIN_TOKEN_EXPIRY_HOURS=1/12 +ALLOW_REGISTER=true +ALLOW_CREATE_WORKSPACE=true \ No newline at end of file diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index f2efa52de3..6b62951b9e 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,6 +1,15 @@ from typing import Optional -from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field +from pydantic import ( + AliasChoices, + Field, + HttpUrl, + NegativeInt, + NonNegativeInt, + PositiveFloat, + PositiveInt, + computed_field, +) from pydantic_settings import BaseSettings from configs.feature.hosted_service import HostedServiceConfig @@ -602,6 +611,21 @@ class PositionConfig(BaseSettings): return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} +class LoginConfig(BaseSettings): + EMAIL_CODE_LOGIN_TOKEN_EXPIRY_HOURS: PositiveFloat = Field( + description="expiry time in hours for email code login token", + default=1 / 12, + ) + ALLOW_REGISTER: bool = Field( + description="whether to enable register", + default=True, + ) + ALLOW_CREATE_WORKSPACE: bool = Field( + description="whether to enable create workspace", + default=True, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -627,6 +651,7 @@ class FeatureConfig( WorkflowConfig, WorkspaceConfig, PositionConfig, + LoginConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index ea23e097d0..a13275a7da 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -29,3 +29,15 @@ class PasswordResetRateLimitExceededError(BaseHTTPException): error_code = "password_reset_rate_limit_exceeded" description = "Password reset rate limit exceeded. Try again later." code = 429 + + +class EmailLoginCodeError(BaseHTTPException): + error_code = "email_login_code_error" + description = "Email login code is invalid or expired." + code = 400 + + +class NotAllowCreateWorkspaceError(BaseHTTPException): + error_code = "workspace_not_found" + description = "Workspace not found." + code = 400 diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 62837af2b9..350276c7bf 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -5,7 +5,15 @@ from flask import request from flask_restful import Resource, reqparse import services +from configs import dify_config +from constants.languages import languages from controllers.console import api +from controllers.console.auth.error import ( + EmailLoginCodeError, + InvalidEmailError, + InvalidTokenError, + NotAllowCreateWorkspaceError, +) from controllers.console.setup import setup_required from libs.helper import email, get_remote_ip from libs.password import valid_password @@ -106,5 +114,64 @@ class ResetPasswordApi(Resource): return {"result": "success"} +class EmailCodeLoginSendEmailApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + args = parser.parse_args() + + account = AccountService.get_user_through_email(args["email"]) + if account is None: + raise InvalidEmailError() + + token = AccountService.send_email_code_login_email(account) + return {"result": "success", "data": token} + + +class EmailCodeLoginApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, location="json") + args = parser.parse_args() + + user_email = args["email"] + + token_data = AccountService.get_email_code_login_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if token_data["email"] != args["email"]: + raise InvalidEmailError() + + if token_data["code"] != args["code"]: + raise EmailLoginCodeError() + + AccountService.revoke_email_code_login_token(args["token"]) + account = AccountService.get_user_through_email(user_email) + if account is None: + # through environment variable, control whether to allow user to register and create workspace + if dify_config.ALLOW_REGISTER: + account = AccountService.create_account( + email=user_email, name=user_email, interface_language=languages[0] + ) + else: + raise InvalidEmailError() + if dify_config.ALLOW_CREATE_WORKSPACE: + TenantService.create_owner_tenant_if_not_exist(account=account) + else: + raise NotAllowCreateWorkspaceError() + + else: + token = AccountService.login(account, ip_address=get_remote_ip(request)) + + return {"result": "success", "data": token} + + api.add_resource(LoginApi, "/login") api.add_resource(LogoutApi, "/logout") +api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") +api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/libs/helper.py b/api/libs/helper.py index af0c2dace1..a486bfc872 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -203,7 +203,8 @@ class TokenManager: expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"] token_key = cls._get_token_key(token, token_type) - redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) + expiry_time = int(expiry_hours * 60 * 60) + redis_client.setex(token_key, expiry_time, json.dumps(token_data)) cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) return token @@ -234,9 +235,12 @@ class TokenManager: return current_token @classmethod - def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int): + def _set_current_token_for_account( + cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float] + ): key = cls._get_account_token_key(account_id, token_type) - redis_client.setex(key, expiry_hours * 60 * 60, token) + expiry_time = int(expiry_hours * 60 * 60) + redis_client.setex(key, expiry_time, token) @classmethod def _get_account_token_key(cls, account_id: str, token_type: str) -> str: diff --git a/api/services/account_service.py b/api/services/account_service.py index cd501c9792..76e1cc64a8 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1,5 +1,6 @@ import base64 import logging +import random import secrets import uuid from datetime import datetime, timedelta, timezone @@ -34,6 +35,7 @@ from services.errors.account import ( RoleAlreadyAssignedError, TenantNotFound, ) +from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -246,6 +248,37 @@ class AccountService: def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "reset_password") + @classmethod + def send_email_code_login_email(cls, account: Account): + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token(account, "email_code_login", {"code": code}) + send_email_code_login_mail_task.delay( + language=account.interface_language, + to=account.email, + code=code, + ) + + return token + + @classmethod + def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "email_code_login") + + @classmethod + def revoke_email_code_login_token(cls, token: str): + TokenManager.revoke_token(token, "email_code_login") + + @classmethod + def get_user_through_email(cls, email: str): + account = db.session.query(Account).filter(Account.email == email).first() + if not account: + return None + + if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: + raise Unauthorized("Account is banned or closed.") + + return account + def _get_login_cache_key(*, account_id: str, token: str): return f"account_login:{account_id}:{token}" diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py new file mode 100644 index 0000000000..d78fc2b891 --- /dev/null +++ b/api/tasks/mail_email_code_login.py @@ -0,0 +1,41 @@ +import logging +import time + +import click +from celery import shared_task +from flask import render_template + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_email_code_login_mail_task(language: str, to: str, code: str): + """ + Async Send email code login mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Email code to be included in the email + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start email code login mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + # send email code login mail using different languages + try: + if language == "zh-Hans": + html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code) + mail.send(to=to, subject="邮箱验证码", html=html_content) + else: + html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Email Code", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send email code login mail to {} failed".format(to))