From 05903e32511fb3af3c617b3f7cfd7ffdcec6cea9 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Thu, 5 Jun 2025 16:00:37 +0900 Subject: [PATCH] Feat/webapp verified sso 260 (#20496) --- .../console/auth/forgot_password.py | 13 +- .../console/explore/installed_app.py | 7 + api/controllers/console/wraps.py | 16 +- api/controllers/web/__init__.py | 15 +- api/controllers/web/app.py | 17 +- api/controllers/web/forgot_password.py | 147 ++++++++++++++++++ api/controllers/web/login.py | 36 ++--- api/controllers/web/passport.py | 98 +++++++++++- api/controllers/web/wraps.py | 58 +++++-- api/extensions/ext_login.py | 3 + api/services/app_service.py | 12 ++ api/services/enterprise/enterprise_service.py | 30 +++- api/services/webapp_auth_service.py | 78 +++++----- 13 files changed, 447 insertions(+), 83 deletions(-) create mode 100644 api/controllers/web/forgot_password.py diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 5b89005671..41c040ee70 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -6,9 +6,13 @@ from flask_restful import Resource, reqparse # type: ignore from constants.languages import languages from controllers.console import api -from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError -from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError -from controllers.console.wraps import email_password_login_enabled, setup_required +from controllers.console.auth.error import (EmailCodeError, InvalidEmailError, + InvalidTokenError, + PasswordMismatchError) +from controllers.console.error import (AccountInFreezeError, AccountNotFound, + EmailSendIpLimitError) +from controllers.console.wraps import (email_password_login_enabled, + setup_required) from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import email, extract_remote_ip @@ -16,7 +20,8 @@ from libs.password import hash_password, valid_password from models.account import Account from services.account_service import AccountService, TenantService from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError +from services.errors.workspace import (WorkSpaceNotAllowedCreateError, + WorkspacesLimitExceededError) from services.feature_service import FeatureService diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index f533389378..af0dc5868e 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -59,7 +59,14 @@ class InstalledAppsListApi(Resource): if FeatureService.get_system_features().webapp_auth.enabled: user_id = current_user.id res = [] + app_ids = [installed_app["app"].id for installed_app in installed_app_list] + webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids) for installed_app in installed_app_list: + webapp_setting = webapp_settings.get(installed_app["app"].id) + if not webapp_setting: + continue + if webapp_setting.access_mode == "sso_verified": + continue app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( user_id=user_id, diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index b8aece3beb..928ea66c43 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -11,7 +11,8 @@ from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus from services.operation_service import OperationService -from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout +from .error import (NotInitValidateError, NotSetupError, + UnauthorizedAndForceLogout) def account_initialization_required(view): @@ -39,7 +40,18 @@ def only_edition_cloud(view): return decorated -def only_enterprise_edition(view): +def only_edition_enterprise(view): + @wraps(view) + def decorated(*args, **kwargs): + if not dify_config.ENTERPRISE_ENABLED: + abort(404) + + return view(*args, **kwargs) + + return decorated + + +def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): if not dify_config.ENTERPRISE_ENABLED: diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 50a04a6254..56749a0e25 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload") api.add_resource(RemoteFileInfoApi, "/remote-files/") api.add_resource(RemoteFileUploadApi, "/remote-files/upload") -from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow +from . import ( + app, + audio, + completion, + conversation, + feature, + forgot_password, + login, + message, + passport, + saved_message, + site, + workflow, +) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index a3cd17e891..417aac25c8 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -11,6 +11,7 @@ from libs.passport import PassportService from models.model import App, AppMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService +from services.webapp_auth_service import WebAppAuthService class AppParameterApi(WebApiResource): @@ -49,10 +50,18 @@ class AppMeta(WebApiResource): class AppAccessMode(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument("appId", type=str, required=True, location="args") + parser.add_argument("appId", type=str, required=False, location="args") + parser.add_argument("appCode", type=str, required=False, location="args") args = parser.parse_args() - app_id = args["appId"] + app_id = args.get("appId") + if args.get("appCode"): + app_code = args["appCode"] + app_id = AppService.get_app_id_by_code(app_code) + + if not app_id: + raise ValueError("appId or appCode must be provided") + res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) return {"accessMode": res.access_mode} @@ -85,7 +94,9 @@ class AppWebAuthPermission(Resource): app_id = args["appId"] app_code = AppService.get_app_code_by_id(app_id) - res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) + res = True + if WebAppAuthService.is_app_require_permission_check(app_id=app_id): + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) return {"result": res} diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py new file mode 100644 index 0000000000..0da8d65efc --- /dev/null +++ b/api/controllers/web/forgot_password.py @@ -0,0 +1,147 @@ +import base64 +import secrets + +from flask import request +from flask_restful import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console.auth.error import ( + EmailCodeError, + EmailPasswordResetLimitError, + InvalidEmailError, + InvalidTokenError, + PasswordMismatchError, +) +from controllers.console.error import AccountNotFound, EmailSendIpLimitError +from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required +from controllers.web import api +from extensions.ext_database import db +from libs.helper import email, extract_remote_ip +from libs.password import hash_password, valid_password +from models.account import Account +from services.account_service import AccountService + + +class ForgotPasswordSendEmailApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + token = None + if account is None: + raise AccountNotFound() + else: + token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + + return {"result": "success", "data": token} + + +class ForgotPasswordCheckApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) + if is_forgot_password_error_rate_limit: + raise EmailPasswordResetLimitError() + + token_data = AccountService.get_reset_password_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_forgot_password_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_reset_password_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_reset_password_token( + user_email, code=args["code"], additional_data={"phase": "reset"} + ) + + AccountService.reset_forgot_password_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class ForgotPasswordResetApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") + args = parser.parse_args() + + # Validate passwords match + if args["new_password"] != args["password_confirm"]: + raise PasswordMismatchError() + + # Validate token and get reset data + reset_data = AccountService.get_reset_password_data(args["token"]) + if not reset_data: + raise InvalidTokenError() + # Must use token in reset phase + if reset_data.get("phase", "") != "reset": + raise InvalidTokenError() + + # Revoke token to prevent reuse + AccountService.revoke_reset_password_token(args["token"]) + + # Generate secure salt and hash password + salt = secrets.token_bytes(16) + password_hashed = hash_password(args["new_password"], salt) + + email = reset_data.get("email", "") + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + + if account: + self._update_existing_account(account, password_hashed, salt, session) + else: + raise AccountNotFound() + + return {"result": "success"} + + def _update_existing_account(self, account, password_hashed, salt, session): + # Update existing account credentials + account.password = base64.b64encode(password_hashed).decode() + account.password_salt = base64.b64encode(salt).decode() + session.commit() + + +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 4106e6a179..97bb90248c 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,13 +1,14 @@ -from flask import request from flask_restful import Resource, reqparse from jwt import InvalidTokenError # type: ignore from web import api -from werkzeug.exceptions import BadRequest import services -from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError +from controllers.console.auth.error import (EmailCodeError, + EmailOrPasswordMismatchError, + InvalidEmailError) from controllers.console.error import AccountBannedError, AccountNotFound -from controllers.console.wraps import setup_required +from controllers.console.wraps import only_edition_enterprise, setup_required +from controllers.web import api from libs.helper import email from libs.password import valid_password from services.account_service import AccountService @@ -17,6 +18,8 @@ from services.webapp_auth_service import WebAppAuthService class LoginApi(Resource): """Resource for web app email/password login.""" + @setup_required + @only_edition_enterprise def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() @@ -24,10 +27,6 @@ class LoginApi(Resource): parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() - app_code = request.headers.get("X-App-Code") - if app_code is None: - raise BadRequest("X-App-Code header is missing.") - try: account = WebAppAuthService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError: @@ -37,12 +36,8 @@ class LoginApi(Resource): except services.errors.account.AccountNotFoundError: raise AccountNotFound() - WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) - - end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code) - - token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) - return {"result": "success", "token": token} + token = WebAppAuthService.login(account=account) + return {"result": "success", "data": {"access_token": token}} # class LogoutApi(Resource): @@ -57,6 +52,7 @@ class LoginApi(Resource): class EmailCodeLoginSendEmailApi(Resource): @setup_required + @only_edition_enterprise def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") @@ -79,6 +75,7 @@ class EmailCodeLoginSendEmailApi(Resource): class EmailCodeLoginApi(Resource): @setup_required + @only_edition_enterprise def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=str, required=True, location="json") @@ -87,9 +84,6 @@ class EmailCodeLoginApi(Resource): args = parser.parse_args() user_email = args["email"] - app_code = request.headers.get("X-App-Code") - if app_code is None: - raise BadRequest("X-App-Code header is missing.") token_data = WebAppAuthService.get_email_code_login_data(args["token"]) if token_data is None: @@ -106,13 +100,9 @@ class EmailCodeLoginApi(Resource): if not account: raise AccountNotFound() - WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) - - end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code) - - token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) + token = WebAppAuthService.login(account=account) AccountService.reset_login_error_rate_limit(args["email"]) - return {"result": "success", "token": token} + return {"result": "success", "data": {"access_token": token}} api.add_resource(LoginApi, "/login") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 8ab9b84574..61b44aa170 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,9 +1,11 @@ import uuid +from datetime import UTC, datetime, timedelta from flask import request -from flask_restful import Resource # type: ignore +from flask_restful import Resource from werkzeug.exceptions import NotFound, Unauthorized +from configs import dify_config from controllers.web import api from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db @@ -19,9 +21,19 @@ class PassportResource(Resource): def get(self): system_features = FeatureService.get_system_features() app_code = request.headers.get("X-App-Code") + web_app_access_token = request.args.get("web_app_access_token") + if app_code is None: raise Unauthorized("X-App-Code header is missing.") + # exchange token for enterprise logined web user + enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token) + if enterprise_user_decoded: + # a web user has already logged in, exchange a token for this app without redirecting to the login page + return exchange_token_for_existing_web_user( + app_code=app_code, enterprise_user_decoded=enterprise_user_decoded + ) + if system_features.webapp_auth.enabled: app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) if not app_settings or not app_settings.access_mode == "public": @@ -65,6 +77,90 @@ class PassportResource(Resource): api.add_resource(PassportResource, "/passport") +def decode_enterprise_webapp_user_id(jwt_token: str | None): + """ + Decode the enterprise user session from the Authorization header. + """ + if not jwt_token: + return None + + decoded = PassportService().verify(jwt_token) + source = decoded.get("token_source") + if not source or source != "webapp_login_token": + raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.") + return decoded + + +def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict): + """ + Exchange a token for an existing web user session. + """ + user_id = enterprise_user_decoded.get("user_id") + end_user_id = enterprise_user_decoded.get("end_user_id") + session_id = enterprise_user_decoded.get("session_id") + auth_type = enterprise_user_decoded.get("auth_type") + + site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + if not site: + raise NotFound() + + app_model = db.session.query(App).filter(App.id == site.app_id).first() + if not app_model or app_model.status != "normal" or not app_model.enable_site: + raise NotFound() + + if not auth_type: + raise Unauthorized("Missing auth_type in the token.") + settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + if settings.access_mode == "sso_verified" and auth_type != "external": + raise WebAppAuthRequiredError("Please login as external user.") + elif settings.access_mode in ["private", "private_all"] and auth_type == "external": + raise WebAppAuthRequiredError("Please login as internal user.") + end_user = None + if end_user_id: + end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + if session_id: + end_user = ( + db.session.query(EndUser) + .filter( + EndUser.session_id == session_id, + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + ) + .first() + ) + if not end_user: + if not session_id: + raise NotFound("Missing session_id for existing web user.") + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=True, + session_id=session_id, + ) + db.session.add(end_user) + db.session.commit() + + exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) + exp = int(exp_dt.timestamp()) + payload = { + "iss": site.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": site.code, + "user_id": user_id, + "end_user_id": end_user.id, + "auth_type": auth_type, + "granted_at": int(datetime.now(UTC).timestamp()), + "token_source": "webapp", + "exp": exp, + } + token: str = PassportService().issue(payload) + return { + "access_token": token, + } + + def generate_session_id(): """ Generate a unique session ID. diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 8d35b8e4be..d3e3e4261b 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,15 +1,19 @@ +from datetime import UTC, datetime from functools import wraps from flask import request from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError +from controllers.web.error import (WebAppAuthAccessDeniedError, + WebAppAuthRequiredError) from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site -from services.enterprise.enterprise_service import EnterpriseService +from services.enterprise.enterprise_service import (EnterpriseService, + WebAppSettings) from services.feature_service import FeatureService +from services.webapp_auth_service import WebAppAuthService def validate_jwt_token(view=None): @@ -45,7 +49,8 @@ def decode_jwt_token(): raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(tk) app_code = decoded.get("app_code") - app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() + app_id = decoded.get("app_id") + app_model = db.session.query(App).filter(App.id == app_id).first() site = db.session.query(Site).filter(Site.code == app_code).first() if not app_model: raise NotFound() @@ -53,19 +58,24 @@ def decode_jwt_token(): raise BadRequest("Site URL is no longer valid.") if app_model.enable_site is False: raise BadRequest("Site is disabled.") - end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() + end_user_id = decoded.get("end_user_id") + end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() if not end_user: raise NotFound() # for enterprise webapp auth app_web_auth_enabled = False + webapp_settings = None if system_features.webapp_auth.enabled: - app_web_auth_enabled = ( - EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" - ) + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + if not webapp_settings: + raise NotFound("Web app settings not found.") + app_web_auth_enabled = webapp_settings.access_mode != "public" _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) - _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) + _validate_user_accessibility( + decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings + ) return app_model, end_user except Unauthorized as e: @@ -95,15 +105,41 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au raise Unauthorized("webapp token expired.") -def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): +def _validate_user_accessibility( + decoded, + app_code, + app_web_auth_enabled: bool, + system_webapp_auth_enabled: bool, + webapp_settings: WebAppSettings | None, +): if system_webapp_auth_enabled and app_web_auth_enabled: # Check if the user is allowed to access the web app user_id = decoded.get("user_id") if not user_id: raise WebAppAuthRequiredError() - if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): - raise WebAppAuthAccessDeniedError() + if not webapp_settings: + raise WebAppAuthRequiredError("Web app settings not found.") + + if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode): + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + raise WebAppAuthAccessDeniedError() + + auth_type = decoded.get("auth_type") + granted_at = decoded.get("granted_at") + if not auth_type: + raise WebAppAuthAccessDeniedError("Missing auth_type in the token.") + if not granted_at: + raise WebAppAuthAccessDeniedError("Missing granted_at in the token.") + # check if sso has been updated + if auth_type == "external": + last_update_time = EnterpriseService.get_app_sso_settings_last_update_time() + if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time: + raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.") + elif auth_type == "internal": + last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time() + if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time: + raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.") class WebApiResource(Resource): diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 10fb89eb73..d23ca96ec2 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -35,6 +35,9 @@ def load_user_from_request(request_from_flask_login): decoded = PassportService().verify(auth_token) user_id = decoded.get("user_id") + source = decoded.get("token_source") + if source: + raise Unauthorized("Invalid Authorization token.") logged_in_account = AccountService.load_logged_in_account(account_id=user_id) return logged_in_account diff --git a/api/services/app_service.py b/api/services/app_service.py index e6a1ae32a9..9fcdf46513 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -396,3 +396,15 @@ class AppService: if not site: raise ValueError(f"App with id {app_id} not found") return str(site.code) + + @staticmethod + def get_app_id_by_code(app_code: str) -> str: + """ + Get app id by app code + :param app_code: app code + :return: app id + """ + site = db.session.query(Site).filter(Site.code == app_code).first() + if not site: + raise ValueError(f"App with code {app_code} not found") + return str(site.app_id) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 9a0c478e75..98147003ff 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,4 +1,6 @@ +from datetime import datetime + from pydantic import BaseModel, Field from services.enterprise.base import EnterpriseRequest @@ -6,7 +8,7 @@ from services.enterprise.base import EnterpriseRequest class WebAppSettings(BaseModel): access_mode: str = Field( - description="Access mode for the web app. Can be 'public' or 'private'", + description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'", default="private", alias="accessMode", ) @@ -18,9 +20,33 @@ class EnterpriseService: return EnterpriseRequest.send_request("GET", "/info") @classmethod - def get_workspace_info(cls, tenant_id:str): + def get_workspace_info(cls, tenant_id: str): return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") + @classmethod + def get_app_sso_settings_last_update_time(cls) -> datetime: + data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time") + print(data) + if not data: + raise ValueError("No data found.") + try: + # parse the UTC timestamp from the response + return datetime.fromisoformat(data.replace("Z", "+00:00")) + except ValueError as e: + raise ValueError(f"Invalid date format: {data}") from e + + @classmethod + def get_workspace_sso_settings_last_update_time(cls) -> datetime: + data = EnterpriseRequest.send_request("GET", "/sso/workspace/last-update-time") + print(data) + if not data: + raise ValueError("No data found.") + try: + # parse the UTC timestamp from the response + return datetime.fromisoformat(data.replace("Z", "+00:00")) + except ValueError as e: + raise ValueError(f"Invalid date format: {data}") from e + class WebAppAuth: @classmethod def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str) -> bool: diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 506b7698e0..b73463c29f 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -2,20 +2,19 @@ import random from datetime import UTC, datetime, timedelta from typing import Any, Optional, cast -from werkzeug.exceptions import NotFound, Unauthorized - from configs import dify_config -from controllers.web.error import WebAppAuthAccessDeniedError from extensions.ext_database import db from libs.helper import TokenManager from libs.passport import PassportService from libs.password import compare_password from models.account import Account, AccountStatus from models.model import App, EndUser, Site +from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService -from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError -from services.feature_service import FeatureService +from services.errors.account import (AccountLoginError, AccountNotFoundError, + AccountPasswordError) from tasks.mail_email_code_login import send_email_code_login_mail_task +from werkzeug.exceptions import Unauthorized class WebAppAuthService: @@ -24,8 +23,7 @@ class WebAppAuthService: @staticmethod def authenticate(email: str, password: str) -> Account: """authenticate account with email and password""" - - account = Account.query.filter_by(email=email).first() + account = db.session.query(Account).filter_by(email=email).first() if not account: raise AccountNotFoundError() @@ -38,12 +36,8 @@ class WebAppAuthService: return cast(Account, account) @classmethod - def login(cls, account: Account, app_code: str, end_user_id: str) -> str: - site = db.session.query(Site).filter(Site.code == app_code).first() - if not site: - raise NotFound("Site not found.") - - access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id) + def login(cls, account: Account) -> str: + access_token = cls._get_account_jwt_token(account=account) return access_token @@ -68,7 +62,7 @@ class WebAppAuthService: code = "".join([str(random.randint(0, 9)) for _ in range(6)]) token = TokenManager.generate_token( - account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code} + account=account, email=email, token_type="email_code_login", additional_data={"code": code} ) send_email_code_login_mail_task.delay( language=language, @@ -80,11 +74,11 @@ class WebAppAuthService: @classmethod def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: - return TokenManager.get_token_data(token, "webapp_email_code_login") + return TokenManager.get_token_data(token, "email_code_login") @classmethod def revoke_email_code_login_token(cls, token: str): - TokenManager.revoke_token(token, "webapp_email_code_login") + TokenManager.revoke_token(token, "email_code_login") @classmethod def create_end_user(cls, app_code, email) -> EndUser: @@ -105,33 +99,45 @@ class WebAppAuthService: return end_user @classmethod - def _validate_user_accessibility(cls, account: Account, app_code: str): - """Check if the user is allowed to access the app.""" - system_features = FeatureService.get_system_features() - if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) - - if ( - app_settings.access_mode != "public" - and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code) - ): - raise WebAppAuthAccessDeniedError() - - @classmethod - def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str: - exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24) + def _get_account_jwt_token(cls, account: Account) -> str: + exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) exp = int(exp_dt.timestamp()) payload = { - "iss": site.id, "sub": "Web API Passport", - "app_id": site.app_id, - "app_code": site.code, "user_id": account.id, - "end_user_id": end_user_id, - "token_source": "webapp", + "session_id": account.email, + "token_source": "webapp_login_token", + "auth_type": "internal", "exp": exp, } token: str = PassportService().issue(payload) return token + + @classmethod + def is_app_require_permission_check( + cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None + ) -> bool: + """ + Check if the app requires permission check based on its access mode. + """ + modes_requiring_permission_check = [ + "private", + "private_all", + ] + if access_mode: + return access_mode in modes_requiring_permission_check + + if not app_code and not app_id: + raise ValueError("Either app_code or app_id must be provided.") + + if app_code: + app_id = AppService.get_app_id_by_code(app_code) + if not app_id: + raise ValueError("App ID could not be determined from the provided app_code.") + + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) + if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check: + return True + return False