From 9a5f21462361c5154a5e785c906bb22d1b3c6931 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 19 Oct 2025 21:29:04 +0800 Subject: [PATCH] refactor: replace localStorage with HTTP-only cookies for auth tokens (#24365) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Signed-off-by: lyzno1 Signed-off-by: kenwoodjw Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Yunlu Wen Co-authored-by: Joel Co-authored-by: GareArc Co-authored-by: NFish Co-authored-by: Davide Delbianco Co-authored-by: minglu7 <1347866672@qq.com> Co-authored-by: Ponder Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: heyszt <270985384@qq.com> Co-authored-by: Asuka Minato Co-authored-by: Guangdong Liu Co-authored-by: Eric Guo Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: XlKsyt Co-authored-by: Dhruv Gorasiya <80987415+DhruvGorasiya@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: hj24 Co-authored-by: GuanMu Co-authored-by: 非法操作 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tonlo <123lzs123@gmail.com> Co-authored-by: Yusuke Yamada Co-authored-by: Novice Co-authored-by: kenwoodjw Co-authored-by: Ademílson Tonato Co-authored-by: znn Co-authored-by: yangzheli <43645580+yangzheli@users.noreply.github.com> --- api/constants/__init__.py | 9 + api/controllers/console/admin.py | 15 +- api/controllers/console/auth/login.py | 81 +++++-- api/controllers/console/auth/oauth.py | 14 +- .../console/explore/installed_app.py | 12 +- api/controllers/console/explore/wraps.py | 4 +- api/controllers/web/app.py | 33 ++- api/controllers/web/login.py | 90 +++++++- api/controllers/web/passport.py | 42 ++-- api/controllers/web/wraps.py | 32 ++- api/extensions/ext_blueprints.py | 9 +- api/extensions/ext_login.py | 15 +- api/libs/external_api.py | 15 ++ api/libs/login.py | 4 + api/libs/token.py | 208 ++++++++++++++++++ api/services/account_service.py | 8 +- api/services/enterprise/enterprise_service.py | 10 +- api/services/webapp_auth_service.py | 3 +- .../services/test_webapp_auth_service.py | 9 +- .../controllers/console/auth/test_oauth.py | 20 +- .../unit_tests/libs/test_external_api.py | 65 ++++++ api/tests/unit_tests/libs/test_login.py | 11 + api/tests/unit_tests/libs/test_token.py | 23 ++ .../components/authenticated-layout.tsx | 9 +- web/app/(shareLayout)/components/splash.tsx | 88 +++++--- .../webapp-signin/check-code/page.tsx | 8 +- .../components/mail-and-password-auth.tsx | 23 +- web/app/(shareLayout)/webapp-signin/page.tsx | 9 +- .../account-page/email-change-modal.tsx | 11 +- web/app/account/(commonLayout)/avatar.tsx | 11 +- .../delete-account/components/feed-back.tsx | 11 +- web/app/account/oauth/authorize/layout.tsx | 19 +- web/app/account/oauth/authorize/page.tsx | 15 +- .../access-control-dialog.tsx | 4 +- .../add-member-or-group-pop.tsx | 2 +- .../base/chat/chat-with-history/index.tsx | 33 --- .../header/account-dropdown/index.tsx | 11 +- .../hooks/use-nodes-sync-draft.ts | 2 +- .../share/text-generation/menu-dropdown.tsx | 9 +- web/app/components/share/utils.ts | 56 ----- web/app/components/swr-initializer.tsx | 28 +-- .../hooks/use-nodes-sync-draft.ts | 2 +- web/app/education-apply/user-info.tsx | 11 +- web/app/install/installForm.tsx | 2 - web/app/signin/check-code/page.tsx | 2 - .../components/mail-and-password-auth.tsx | 3 +- web/app/signin/invite-settings/page.tsx | 3 +- web/app/signin/normal-form.tsx | 18 +- web/app/signup/set-password/page.tsx | 4 +- web/config/index.ts | 11 + web/context/web-app-context.tsx | 17 +- web/models/app.ts | 57 ----- web/service/base.ts | 43 ++-- web/service/common.ts | 10 +- web/service/fetch.ts | 49 ++--- web/service/refresh-token.ts | 8 +- web/service/share.ts | 14 +- web/service/use-common.ts | 22 +- web/service/use-share.ts | 2 + web/service/webapp-auth.ts | 53 +++++ 60 files changed, 879 insertions(+), 533 deletions(-) create mode 100644 api/libs/token.py create mode 100644 api/tests/unit_tests/libs/test_token.py create mode 100644 web/service/webapp-auth.ts diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 9141fbea95..248cdfc09f 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -55,3 +55,12 @@ else: "properties", } DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions) + +COOKIE_NAME_ACCESS_TOKEN = "access_token" +COOKIE_NAME_REFRESH_TOKEN = "refresh_token" +COOKIE_NAME_PASSPORT = "passport" +COOKIE_NAME_CSRF_TOKEN = "csrf_token" + +HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token" +HEADER_NAME_APP_CODE = "X-App-Code" +HEADER_NAME_PASSPORT = "X-App-Passport" diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index ef96184678..2c4d8709eb 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,6 +15,7 @@ from constants.languages import supported_language from controllers.console import api, console_ns from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db +from libs.token import extract_access_token from models.model import App, InstalledApp, RecommendedApp @@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]): if not dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") - auth_header = request.headers.get("Authorization") - if auth_header is None: + auth_token = extract_access_token(request) + if not auth_token: raise Unauthorized("Authorization header is missing.") - - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - if auth_token != dify_config.ADMIN_API_KEY: raise Unauthorized("API key is invalid.") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 3696c88346..277f9a60a8 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,5 +1,5 @@ import flask_login -from flask import request +from flask import make_response, request from flask_restx import Resource, reqparse import services @@ -25,6 +25,16 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.login import current_account_with_tenant +from libs.token import ( + clear_access_token_from_cookie, + clear_csrf_token_from_cookie, + clear_refresh_token_from_cookie, + extract_access_token, + extract_csrf_token, + set_access_token_to_cookie, + set_csrf_token_to_cookie, + set_refresh_token_to_cookie, +) from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService from services.errors.account import AccountRegisterError @@ -89,20 +99,36 @@ class LoginApi(Resource): token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) - return {"result": "success", "data": token_pair.model_dump()} + + # Create response with cookies instead of returning tokens in body + response = make_response({"result": "success"}) + + set_access_token_to_cookie(request, response, token_pair.access_token) + set_refresh_token_to_cookie(request, response, token_pair.refresh_token) + set_csrf_token_to_cookie(request, response, token_pair.csrf_token) + + return response @console_ns.route("/logout") class LogoutApi(Resource): @setup_required - def get(self): + def post(self): current_user, _ = current_account_with_tenant() account = current_user if isinstance(account, flask_login.AnonymousUserMixin): - return {"result": "success"} - AccountService.logout(account=account) - flask_login.logout_user() - return {"result": "success"} + response = make_response({"result": "success"}) + else: + AccountService.logout(account=account) + flask_login.logout_user() + response = make_response({"result": "success"}) + + # Clear cookies on logout + clear_access_token_from_cookie(response) + clear_refresh_token_from_cookie(response) + clear_csrf_token_from_cookie(response) + + return response @console_ns.route("/reset-password") @@ -227,17 +253,46 @@ class EmailCodeLoginApi(Resource): raise WorkspacesLimitExceeded() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) - return {"result": "success", "data": token_pair.model_dump()} + + # Create response with cookies instead of returning tokens in body + response = make_response({"result": "success"}) + + set_csrf_token_to_cookie(request, response, token_pair.csrf_token) + # Set HTTP-only secure cookies for tokens + set_access_token_to_cookie(request, response, token_pair.access_token) + set_refresh_token_to_cookie(request, response, token_pair.refresh_token) + return response @console_ns.route("/refresh-token") class RefreshTokenApi(Resource): def post(self): - parser = reqparse.RequestParser().add_argument("refresh_token", type=str, required=True, location="json") - args = parser.parse_args() + # Get refresh token from cookie instead of request body + refresh_token = request.cookies.get("refresh_token") + + if not refresh_token: + return {"result": "fail", "message": "No refresh token provided"}, 401 try: - new_token_pair = AccountService.refresh_token(args["refresh_token"]) - return {"result": "success", "data": new_token_pair.model_dump()} + new_token_pair = AccountService.refresh_token(refresh_token) + + # Create response with new cookies + response = make_response({"result": "success"}) + + # Update cookies with new tokens + set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token) + set_access_token_to_cookie(request, response, new_token_pair.access_token) + set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token) + return response except Exception as e: - return {"result": "fail", "data": str(e)}, 401 + return {"result": "fail", "message": str(e)}, 401 + + +# this api helps frontend to check whether user is authenticated +# TODO: remove in the future. frontend should redirect to login page by catching 401 status +@console_ns.route("/login/status") +class LoginStatus(Resource): + def get(self): + token = extract_access_token(request) + csrf_token = extract_csrf_token(request) + return {"logged_in": bool(token) and bool(csrf_token)} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 52459ad5eb..29653b32ec 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -14,6 +14,11 @@ from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo +from libs.token import ( + set_access_token_to_cookie, + set_csrf_token_to_cookie, + set_refresh_token_to_cookie, +) from models import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService @@ -152,9 +157,12 @@ class OAuthCallback(Resource): ip_address=extract_remote_ip(request), ) - return redirect( - f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" - ) + response = redirect(f"{dify_config.CONSOLE_WEB_URL}") + + set_access_token_to_cookie(request, response, token_pair.access_token) + set_refresh_token_to_cookie(request, response, token_pair.refresh_token) + set_csrf_token_to_cookie(request, response, token_pair.csrf_token) + return response def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index dec84b68f4..3c95779475 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -15,7 +15,6 @@ from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService -from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -67,31 +66,26 @@ class InstalledAppsListApi(Resource): # Pre-filter out apps without setting or with sso_verified filtered_installed_apps = [] - app_id_to_app_code = {} for installed_app in installed_app_list: app_id = installed_app["app"].id webapp_setting = webapp_settings.get(app_id) if not webapp_setting or webapp_setting.access_mode == "sso_verified": continue - app_code = AppService.get_app_code_by_id(str(app_id)) - app_id_to_app_code[app_id] = app_code filtered_installed_apps.append(installed_app) - app_codes = list(app_id_to_app_code.values()) - # Batch permission check + app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps] permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps( user_id=user_id, - app_codes=app_codes, + app_ids=app_ids, ) # Keep only allowed apps res = [] for installed_app in filtered_installed_apps: app_id = installed_app["app"].id - app_code = app_id_to_app_code[app_id] - if permissions.get(app_code): + if permissions.get(app_id): res.append(installed_app) installed_app_list = res diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index df4eed18eb..2a97d312aa 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -10,7 +10,6 @@ from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import InstalledApp -from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -56,10 +55,9 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | feature = FeatureService.get_system_features() if feature.webapp_auth.enabled: app_id = installed_app.app_id - app_code = AppService.get_app_code_by_id(app_id) res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( user_id=str(current_user.id), - app_code=app_code, + app_id=app_id, ) if not res: raise AppAccessDeniedError() diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index d7facdbbb3..60193f5f15 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -4,12 +4,14 @@ from flask import request from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Unauthorized +from constants import HEADER_NAME_APP_CODE from controllers.common import fields from controllers.web import web_ns from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from libs.passport import PassportService +from libs.token import extract_webapp_passport from models.model import App, AppMode from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService @@ -133,18 +135,19 @@ class AppWebAuthPermission(Resource): ) def get(self): user_id = "visitor" + app_code = request.headers.get(HEADER_NAME_APP_CODE) + app_id = request.args.get("appId") + if not app_id or not app_code: + raise ValueError("appId must be provided") + + require_permission_check = WebAppAuthService.is_app_require_permission_check(app_id=app_id) + if not require_permission_check: + return {"result": True} + try: - auth_header = request.headers.get("Authorization") - if auth_header is None: - raise Unauthorized("Authorization header is missing.") - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - auth_scheme, tk = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Authorization scheme must be 'Bearer'") - + tk = extract_webapp_passport(app_code, request) + if not tk: + raise Unauthorized("Access token is missing.") decoded = PassportService().verify(tk) user_id = decoded.get("user_id", "visitor") except Unauthorized: @@ -157,13 +160,7 @@ class AppWebAuthPermission(Resource): if not features.webapp_auth.enabled: return {"result": True} - parser = reqparse.RequestParser().add_argument("appId", type=str, required=True, location="args") - args = parser.parse_args() - - app_id = args["appId"] - app_code = AppService.get_app_code_by_id(app_id) - res = True if WebAppAuthService.is_app_require_permission_check(app_id=app_id): - res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_id) return {"result": res} diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 351f245f4a..f213fd8c90 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,7 +1,9 @@ +from flask import make_response, request from flask_restx import Resource, reqparse from jwt import InvalidTokenError import services +from configs import dify_config from controllers.console.auth.error import ( AuthenticationFailedError, EmailCodeError, @@ -10,9 +12,16 @@ from controllers.console.auth.error import ( from controllers.console.error import AccountBannedError from controllers.console.wraps import only_edition_enterprise, setup_required from controllers.web import web_ns +from controllers.web.wraps import decode_jwt_token from libs.helper import email +from libs.passport import PassportService from libs.password import valid_password +from libs.token import ( + clear_access_token_from_cookie, + extract_access_token, +) from services.account_service import AccountService +from services.app_service import AppService from services.webapp_auth_service import WebAppAuthService @@ -52,17 +61,75 @@ class LoginApi(Resource): raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) - return {"result": "success", "data": {"access_token": token}} + response = make_response({"result": "success", "data": {"access_token": token}}) + # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) + return response -# class LogoutApi(Resource): -# @setup_required -# def get(self): -# account = cast(Account, flask_login.current_user) -# if isinstance(account, flask_login.AnonymousUserMixin): -# return {"result": "success"} -# flask_login.logout_user() -# return {"result": "success"} +# this api helps frontend to check whether user is authenticated +# TODO: remove in the future. frontend should redirect to login page by catching 401 status +@web_ns.route("/login/status") +class LoginStatusApi(Resource): + @setup_required + @web_ns.doc("web_app_login_status") + @web_ns.doc(description="Check login status") + @web_ns.doc( + responses={ + 200: "Login status", + 401: "Login status", + } + ) + def get(self): + app_code = request.args.get("app_code") + token = extract_access_token(request) + if not app_code: + return { + "logged_in": bool(token), + "app_logged_in": False, + } + app_id = AppService.get_app_id_by_code(app_code) + is_public = not dify_config.ENTERPRISE_ENABLED or not WebAppAuthService.is_app_require_permission_check( + app_id=app_id + ) + user_logged_in = False + + if is_public: + user_logged_in = True + else: + try: + PassportService().verify(token=token) + user_logged_in = True + except Exception: + user_logged_in = False + + try: + _ = decode_jwt_token(app_code=app_code) + app_logged_in = True + except Exception: + app_logged_in = False + + return { + "logged_in": user_logged_in, + "app_logged_in": app_logged_in, + } + + +@web_ns.route("/logout") +class LogoutApi(Resource): + @setup_required + @web_ns.doc("web_app_logout") + @web_ns.doc(description="Logout user from web application") + @web_ns.doc( + responses={ + 200: "Logout successful", + } + ) + def post(self): + response = make_response({"result": "success"}) + # enterprise SSO sets same site to None in https deployment + # so we need to logout by calling api + clear_access_token_from_cookie(response, samesite="None") + return response @web_ns.route("/email-code-login") @@ -96,7 +163,6 @@ class EmailCodeLoginSendEmailApi(Resource): raise AuthenticationFailedError() else: token = WebAppAuthService.send_email_code_login_email(account=account, language=language) - return {"result": "success", "data": token} @@ -142,4 +208,6 @@ class EmailCodeLoginApi(Resource): token = WebAppAuthService.login(account=account) AccountService.reset_login_error_rate_limit(args["email"]) - return {"result": "success", "data": {"access_token": token}} + response = make_response({"result": "success", "data": {"access_token": token}}) + # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) + return response diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 7190f06426..776b743e92 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,17 +1,20 @@ import uuid from datetime import UTC, datetime, timedelta -from flask import request +from flask import make_response, request from flask_restx import Resource from sqlalchemy import func, select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config +from constants import HEADER_NAME_APP_CODE from controllers.web import web_ns from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService +from libs.token import extract_access_token from models.model import App, EndUser, Site +from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService, WebAppAuthType @@ -32,15 +35,15 @@ class PassportResource(Resource): ) def get(self): system_features = FeatureService.get_system_features() - app_code = request.headers.get("X-App-Code") + app_code = request.headers.get(HEADER_NAME_APP_CODE) user_id = request.args.get("user_id") - web_app_access_token = request.args.get("web_app_access_token") + access_token = extract_access_token(request) if app_code is None: raise Unauthorized("X-App-Code header is missing.") - + app_id = AppService.get_app_id_by_code(app_code) # exchange token for enterprise logined web user - enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token) + enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token) if enterprise_user_decoded: # a web user has already logged in, exchange a token for this app without redirecting to the login page return exchange_token_for_existing_web_user( @@ -48,7 +51,7 @@ class PassportResource(Resource): ) if system_features.webapp_auth.enabled: - app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) if not app_settings or not app_settings.access_mode == "public": raise WebAppAuthRequiredError() @@ -99,9 +102,12 @@ class PassportResource(Resource): tk = PassportService().issue(payload) - return { - "access_token": tk, - } + response = make_response( + { + "access_token": tk, + } + ) + return response def decode_enterprise_webapp_user_id(jwt_token: str | None): @@ -189,9 +195,12 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: "exp": exp, } token: str = PassportService().issue(payload) - return { - "access_token": token, - } + resp = make_response( + { + "access_token": token, + } + ) + return resp def _exchange_for_public_app_token(app_model, site, token_decoded): @@ -224,9 +233,12 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): tk = PassportService().issue(payload) - return { - "access_token": tk, - } + resp = make_response( + { + "access_token": tk, + } + ) + return resp def generate_session_id(): diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index ba03c4eae4..9efd9f25d1 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -9,10 +9,13 @@ from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound, Unauthorized +from constants import HEADER_NAME_APP_CODE from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService +from libs.token import extract_webapp_passport from models.model import App, EndUser, Site +from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings from services.feature_service import FeatureService from services.webapp_auth_service import WebAppAuthService @@ -35,22 +38,14 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = return decorator -def decode_jwt_token(): +def decode_jwt_token(app_code: str | None = None): system_features = FeatureService.get_system_features() - app_code = str(request.headers.get("X-App-Code")) + if not app_code: + app_code = str(request.headers.get(HEADER_NAME_APP_CODE)) try: - auth_header = request.headers.get("Authorization") - if auth_header is None: - raise Unauthorized("Authorization header is missing.") - - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - auth_scheme, tk = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + tk = extract_webapp_passport(app_code, request) + if not tk: + raise Unauthorized("App token is missing.") decoded = PassportService().verify(tk) app_code = decoded.get("app_code") app_id = decoded.get("app_id") @@ -72,7 +67,8 @@ def decode_jwt_token(): app_web_auth_enabled = False webapp_settings = None if system_features.webapp_auth.enabled: - webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + app_id = AppService.get_app_id_by_code(app_code) + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) if not webapp_settings: raise NotFound("Web app settings not found.") app_web_auth_enabled = webapp_settings.access_mode != "public" @@ -87,8 +83,9 @@ def decode_jwt_token(): if system_features.webapp_auth.enabled: if not app_code: raise Unauthorized("Please re-login to access the web app.") + app_id = AppService.get_app_id_by_code(app_code) app_web_auth_enabled = ( - EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public" + EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public" ) if app_web_auth_enabled: raise WebAppAuthRequiredError() @@ -129,7 +126,8 @@ def _validate_user_accessibility( raise WebAppAuthRequiredError("Web app settings not found.") if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode): - if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + app_id = AppService.get_app_id_by_code(app_code) + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_id): raise WebAppAuthAccessDeniedError() auth_type = decoded.get("auth_type") diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 9c08a08c45..52fef4929f 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -1,4 +1,5 @@ from configs import dify_config +from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN from dify_app import DifyApp @@ -16,7 +17,7 @@ def init_app(app: DifyApp): CORS( service_api_bp, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], + allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(service_api_bp) @@ -25,7 +26,7 @@ def init_app(app: DifyApp): web_bp, resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], + allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -35,7 +36,7 @@ def init_app(app: DifyApp): console_app_bp, resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, supports_credentials=True, - allow_headers=["Content-Type", "Authorization"], + allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) @@ -43,7 +44,7 @@ def init_app(app: DifyApp): CORS( files_bp, - allow_headers=["Content-Type"], + allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], ) app.register_blueprint(files_bp) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 836a5d938c..e7816a2e88 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -9,6 +9,7 @@ from configs import dify_config from dify_app import DifyApp from extensions.ext_database import db from libs.passport import PassportService +from libs.token import extract_access_token from models import Account, Tenant, TenantAccountJoin from models.model import AppMCPServer, EndUser from services.account_service import AccountService @@ -24,20 +25,10 @@ def load_user_from_request(request_from_flask_login): if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")): return None - auth_header = request.headers.get("Authorization", "") - auth_token: str | None = None - if auth_header: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(maxsplit=1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - else: - auth_token = request.args.get("_token") + auth_token = extract_access_token(request) # Check for admin API key authentication first - if dify_config.ADMIN_API_KEY_ENABLE and auth_header: + if dify_config.ADMIN_API_KEY_ENABLE and auth_token: admin_api_key = dify_config.ADMIN_API_KEY if admin_api_key and admin_api_key == auth_token: workspace_id = request.headers.get("X-WORKSPACE-ID") diff --git a/api/libs/external_api.py b/api/libs/external_api.py index a59230caaa..f3ebcc4306 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -9,7 +9,9 @@ from werkzeug.exceptions import HTTPException from werkzeug.http import HTTP_STATUS_CODES from configs import dify_config +from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN from core.errors.error import AppInvokeQuotaExceededError +from libs.token import is_secure def http_status_message(code): @@ -67,6 +69,19 @@ def register_external_error_handlers(api: Api): # If you need WWW-Authenticate for 401, add it to headers if status_code == 401: headers["WWW-Authenticate"] = 'Bearer realm="api"' + # Check if this is a forced logout error - clear cookies + error_code = getattr(e, "error_code", None) + if error_code == "unauthorized_and_force_logout": + # Add Set-Cookie headers to clear auth cookies + + secure = is_secure() + # response is not accessible, so we need to do it ugly + common_part = "Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly" + headers["Set-Cookie"] = [ + f'{COOKIE_NAME_ACCESS_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax', + f'{COOKIE_NAME_CSRF_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax', + f'{COOKIE_NAME_REFRESH_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax', + ] return data, status_code, headers _ = handle_http_exception diff --git a/api/libs/login.py b/api/libs/login.py index d0e81a3441..5ed4bfae8f 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -7,6 +7,7 @@ from flask_login.config import EXEMPT_METHODS # type: ignore from werkzeug.local import LocalProxy from configs import dify_config +from libs.token import check_csrf_token from models import Account from models.model import EndUser @@ -73,6 +74,9 @@ def login_required(func: Callable[P, R]): pass elif current_user is not None and not current_user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore + # we put csrf validation here for less conflicts + # TODO: maybe find a better place for it. + check_csrf_token(request, current_user.id) return current_app.ensure_sync(func)(*args, **kwargs) return decorated_view diff --git a/api/libs/token.py b/api/libs/token.py new file mode 100644 index 0000000000..4be25696e7 --- /dev/null +++ b/api/libs/token.py @@ -0,0 +1,208 @@ +import logging +import re +from datetime import UTC, datetime, timedelta + +from flask import Request +from werkzeug.exceptions import Unauthorized +from werkzeug.wrappers import Response + +from configs import dify_config +from constants import ( + COOKIE_NAME_ACCESS_TOKEN, + COOKIE_NAME_CSRF_TOKEN, + COOKIE_NAME_PASSPORT, + COOKIE_NAME_REFRESH_TOKEN, + HEADER_NAME_CSRF_TOKEN, + HEADER_NAME_PASSPORT, +) +from libs.passport import PassportService + +logger = logging.getLogger(__name__) + +CSRF_WHITE_LIST = [ + re.compile(r"/console/api/apps/[a-f0-9-]+/workflows/draft"), +] + + +# server is behind a reverse proxy, so we need to check the url +def is_secure() -> bool: + return dify_config.CONSOLE_WEB_URL.startswith("https") and dify_config.CONSOLE_API_URL.startswith("https") + + +def _real_cookie_name(cookie_name: str) -> str: + if is_secure(): + return "__Host-" + cookie_name + else: + return cookie_name + + +def _try_extract_from_header(request: Request) -> str | None: + """ + Try to extract access token from header + """ + auth_header = request.headers.get("Authorization") + if auth_header: + if " " not in auth_header: + return None + else: + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + return None + else: + return auth_token + return None + + +def extract_csrf_token(request: Request) -> str | None: + """ + Try to extract CSRF token from header or cookie. + """ + return request.headers.get(HEADER_NAME_CSRF_TOKEN) + + +def extract_csrf_token_from_cookie(request: Request) -> str | None: + """ + Try to extract CSRF token from cookie. + """ + return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN)) + + +def extract_access_token(request: Request) -> str | None: + """ + Try to extract access token from cookie, header or params. + + Access token is either for console session or webapp passport exchange. + """ + + def _try_extract_from_cookie(request: Request) -> str | None: + return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN)) + + return _try_extract_from_cookie(request) or _try_extract_from_header(request) + + +def extract_webapp_passport(app_code: str, request: Request) -> str | None: + """ + Try to extract app token from header or params. + + Webapp access token (part of passport) is only used for webapp session. + """ + + def _try_extract_passport_token_from_cookie(request: Request) -> str | None: + return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) + + def _try_extract_passport_token_from_header(request: Request) -> str | None: + return request.headers.get(HEADER_NAME_PASSPORT) + + ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request) + return ret + + +def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"): + response.set_cookie( + _real_cookie_name(COOKIE_NAME_ACCESS_TOKEN), + value=token, + httponly=True, + secure=is_secure(), + samesite=samesite, + max_age=int(dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 60), + path="/", + ) + + +def set_refresh_token_to_cookie(request: Request, response: Response, token: str): + response.set_cookie( + _real_cookie_name(COOKIE_NAME_REFRESH_TOKEN), + value=token, + httponly=True, + secure=is_secure(), + samesite="Lax", + max_age=int(60 * 60 * 24 * dify_config.REFRESH_TOKEN_EXPIRE_DAYS), + path="/", + ) + + +def set_csrf_token_to_cookie(request: Request, response: Response, token: str): + response.set_cookie( + _real_cookie_name(COOKIE_NAME_CSRF_TOKEN), + value=token, + httponly=False, + secure=is_secure(), + samesite="Lax", + max_age=int(60 * dify_config.ACCESS_TOKEN_EXPIRE_MINUTES), + path="/", + ) + + +def _clear_cookie( + response: Response, + cookie_name: str, + samesite: str = "Lax", + http_only: bool = True, +): + response.set_cookie( + _real_cookie_name(cookie_name), + "", + expires=0, + path="/", + secure=is_secure(), + httponly=http_only, + samesite=samesite, + ) + + +def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"): + _clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite) + + +def clear_refresh_token_from_cookie(response: Response): + _clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN) + + +def clear_csrf_token_from_cookie(response: Response): + _clear_cookie(response, COOKIE_NAME_CSRF_TOKEN, http_only=False) + + +def check_csrf_token(request: Request, user_id: str): + # some apis are sent by beacon, so we need to bypass csrf token check + # since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required. + def _unauthorized(): + raise Unauthorized("CSRF token is missing or invalid.") + + for pattern in CSRF_WHITE_LIST: + if pattern.match(request.path): + return + + csrf_token = extract_csrf_token(request) + csrf_token_from_cookie = extract_csrf_token_from_cookie(request) + + if csrf_token != csrf_token_from_cookie: + _unauthorized() + + if not csrf_token: + _unauthorized() + verified = {} + try: + verified = PassportService().verify(csrf_token) + except: + _unauthorized() + + if verified.get("sub") != user_id: + _unauthorized() + + exp: int | None = verified.get("exp") + if not exp: + _unauthorized() + else: + time_now = int(datetime.now().timestamp()) + if exp < time_now: + _unauthorized() + + +def generate_csrf_token(user_id: str) -> str: + exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) + payload = { + "exp": int(exp_dt.timestamp()), + "sub": user_id, + } + return PassportService().issue(payload) diff --git a/api/services/account_service.py b/api/services/account_service.py index 106bc0e77e..cb0eb7a9dd 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -22,6 +22,7 @@ from libs.helper import RateLimiter, TokenManager from libs.passport import PassportService from libs.password import compare_password, hash_password, valid_password from libs.rsa import generate_key_pair +from libs.token import generate_csrf_token from models.account import ( Account, AccountIntegrate, @@ -76,6 +77,7 @@ logger = logging.getLogger(__name__) class TokenPair(BaseModel): access_token: str refresh_token: str + csrf_token: str REFRESH_TOKEN_PREFIX = "refresh_token:" @@ -403,10 +405,11 @@ class AccountService: access_token = AccountService.get_account_jwt_token(account=account) refresh_token = _generate_refresh_token() + csrf_token = generate_csrf_token(account.id) AccountService._store_refresh_token(refresh_token, account.id) - return TokenPair(access_token=access_token, refresh_token=refresh_token) + return TokenPair(access_token=access_token, refresh_token=refresh_token, csrf_token=csrf_token) @staticmethod def logout(*, account: Account): @@ -431,8 +434,9 @@ class AccountService: AccountService._delete_refresh_token(refresh_token, account.id) AccountService._store_refresh_token(new_refresh_token, account.id) + csrf_token = generate_csrf_token(account.id) - return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) + return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token) @staticmethod def load_logged_in_account(*, account_id: str): diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 4fbf33fd6f..974aa849db 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -46,17 +46,17 @@ class EnterpriseService: class WebAppAuth: @classmethod - def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str): - params = {"userId": user_id, "appCode": app_code} + def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str): + params = {"userId": user_id, "appId": app_id} data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) return data.get("result", False) @classmethod - def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]): - if not app_codes: + def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_ids: list[str]): + if not app_ids: return {} - body = {"userId": user_id, "appCodes": app_codes} + body = {"userId": user_id, "appIds": app_ids} data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body) if not data: raise ValueError("No data found.") diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 693bfb95b6..9bd797a45f 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -172,7 +172,8 @@ class WebAppAuthService: return WebAppAuthType.EXTERNAL if app_code: - webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code) + app_id = AppService.get_app_id_by_code(app_code) + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) return cls.get_app_auth_type(access_mode=webapp_settings.access_mode) raise ValueError("Could not determine app authentication type.") diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 9fc16d9eb7..73e622b061 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -863,13 +863,14 @@ class TestWebAppAuthService: - Mock service integration """ # Arrange: Setup mock for enterprise service - mock_webapp_auth = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})() + mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id" + setting = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})() mock_external_service_dependencies[ "enterprise_service" - ].WebAppAuth.get_app_access_mode_by_code.return_value = mock_webapp_auth + ].WebAppAuth.get_app_access_mode_by_id.return_value = setting # Act: Execute authentication type determination - result = WebAppAuthService.get_app_auth_type(app_code="mock_app_code") + result: WebAppAuthType = WebAppAuthService.get_app_auth_type(app_code="mock_app_code") # Assert: Verify correct result assert result == WebAppAuthType.EXTERNAL @@ -877,7 +878,7 @@ class TestWebAppAuthService: # Verify mock service was called correctly mock_external_service_dependencies[ "enterprise_service" - ].WebAppAuth.get_app_access_mode_by_code.assert_called_once_with("mock_app_code") + ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id") def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies): """ diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py index 67f4b85413..399caf8c4d 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py @@ -179,9 +179,7 @@ class TestOAuthCallback: oauth_setup["provider"].get_access_token.assert_called_once_with("test_code") oauth_setup["provider"].get_user_info.assert_called_once_with("access_token") - mock_redirect.assert_called_once_with( - "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token" - ) + mock_redirect.assert_called_once_with("http://localhost:3000") @pytest.mark.parametrize( ("exception", "expected_error"), @@ -224,8 +222,8 @@ class TestOAuthCallback: # CLOSED status: Currently NOT handled, will proceed to login (security issue) # This documents actual behavior. See test_defensive_check_for_closed_account_status for details ( - AccountStatus.CLOSED, - "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token", + AccountStatus.CLOSED.value, + "http://localhost:3000", ), ], ) @@ -268,6 +266,7 @@ class TestOAuthCallback: mock_token_pair = MagicMock() mock_token_pair.access_token = "jwt_access_token" mock_token_pair.refresh_token = "jwt_refresh_token" + mock_token_pair.csrf_token = "csrf_token" mock_account_service.login.return_value = mock_token_pair with app.test_request_context("/auth/oauth/github/callback?code=test_code"): @@ -299,6 +298,12 @@ class TestOAuthCallback: mock_account.status = AccountStatus.PENDING mock_generate_account.return_value = mock_account + mock_token_pair = MagicMock() + mock_token_pair.access_token = "jwt_access_token" + mock_token_pair.refresh_token = "jwt_refresh_token" + mock_token_pair.csrf_token = "csrf_token" + mock_account_service.login.return_value = mock_token_pair + with app.test_request_context("/auth/oauth/github/callback?code=test_code"): resource.get("github") @@ -361,6 +366,7 @@ class TestOAuthCallback: mock_token_pair = MagicMock() mock_token_pair.access_token = "jwt_access_token" mock_token_pair.refresh_token = "jwt_refresh_token" + mock_token_pair.csrf_token = "csrf_token" mock_account_service.login.return_value = mock_token_pair # Execute OAuth callback @@ -368,9 +374,7 @@ class TestOAuthCallback: resource.get("github") # Verify current behavior: login succeeds (this is NOT ideal) - mock_redirect.assert_called_once_with( - "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token" - ) + mock_redirect.assert_called_once_with("http://localhost:3000") mock_account_service.login.assert_called_once() # Document expected behavior in comments: diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py index a9edb913ea..c4c376a070 100644 --- a/api/tests/unit_tests/libs/test_external_api.py +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -2,7 +2,9 @@ from flask import Blueprint, Flask from flask_restx import Resource from werkzeug.exceptions import BadRequest, Unauthorized +from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN from core.errors.error import AppInvokeQuotaExceededError +from libs.exception import BaseHTTPException from libs.external_api import ExternalApi @@ -120,3 +122,66 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none(): assert res.status_code in (400, 429) finally: ext.sys.exc_info = orig_exc_info # type: ignore[assignment] + + +def test_unauthorized_and_force_logout_clears_cookies(): + """Test that UnauthorizedAndForceLogout error clears auth cookies""" + + class UnauthorizedAndForceLogout(BaseHTTPException): + error_code = "unauthorized_and_force_logout" + description = "Unauthorized and force logout." + code = 401 + + app = Flask(__name__) + bp = Blueprint("test", __name__) + api = ExternalApi(bp) + + @api.route("/force-logout") + class ForceLogout(Resource): # type: ignore + def get(self): # type: ignore + raise UnauthorizedAndForceLogout() + + app.register_blueprint(bp, url_prefix="/api") + client = app.test_client() + + # Set cookies first + client.set_cookie(COOKIE_NAME_ACCESS_TOKEN, "test_access_token") + client.set_cookie(COOKIE_NAME_CSRF_TOKEN, "test_csrf_token") + client.set_cookie(COOKIE_NAME_REFRESH_TOKEN, "test_refresh_token") + + # Make request that should trigger cookie clearing + res = client.get("/api/force-logout") + + # Verify response + assert res.status_code == 401 + data = res.get_json() + assert data["code"] == "unauthorized_and_force_logout" + assert data["status"] == 401 + assert "WWW-Authenticate" in res.headers + + # Verify Set-Cookie headers are present to clear cookies + set_cookie_headers = res.headers.getlist("Set-Cookie") + assert len(set_cookie_headers) == 3, f"Expected 3 Set-Cookie headers, got {len(set_cookie_headers)}" + + # Verify each cookie is being cleared (empty value and expired) + cookie_names_found = set() + for cookie_header in set_cookie_headers: + # Check for cookie names + if COOKIE_NAME_ACCESS_TOKEN in cookie_header: + cookie_names_found.add(COOKIE_NAME_ACCESS_TOKEN) + assert '""' in cookie_header or "=" in cookie_header # Empty value + assert "Expires=Thu, 01 Jan 1970" in cookie_header # Expired + elif COOKIE_NAME_CSRF_TOKEN in cookie_header: + cookie_names_found.add(COOKIE_NAME_CSRF_TOKEN) + assert '""' in cookie_header or "=" in cookie_header + assert "Expires=Thu, 01 Jan 1970" in cookie_header + elif COOKIE_NAME_REFRESH_TOKEN in cookie_header: + cookie_names_found.add(COOKIE_NAME_REFRESH_TOKEN) + assert '""' in cookie_header or "=" in cookie_header + assert "Expires=Thu, 01 Jan 1970" in cookie_header + + # Verify all three cookies are present + assert len(cookie_names_found) == 3 + assert COOKIE_NAME_ACCESS_TOKEN in cookie_names_found + assert COOKIE_NAME_CSRF_TOKEN in cookie_names_found + assert COOKIE_NAME_REFRESH_TOKEN in cookie_names_found diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index 39671077d4..35155b4931 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -19,10 +19,15 @@ class MockUser(UserMixin): return self._is_authenticated +def mock_csrf_check(*args, **kwargs): + return + + class TestLoginRequired: """Test cases for login_required decorator.""" @pytest.fixture + @patch("libs.login.check_csrf_token", mock_csrf_check) def setup_app(self, app: Flask): """Set up Flask app with login manager.""" # Initialize login manager @@ -39,6 +44,7 @@ class TestLoginRequired: return app + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_authenticated_user_can_access_protected_view(self, setup_app: Flask): """Test that authenticated users can access protected views.""" @@ -53,6 +59,7 @@ class TestLoginRequired: result = protected_view() assert result == "Protected content" + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask): """Test that unauthenticated users are redirected.""" @@ -68,6 +75,7 @@ class TestLoginRequired: assert result == "Unauthorized" setup_app.login_manager.unauthorized.assert_called_once() + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask): """Test that LOGIN_DISABLED config bypasses authentication.""" @@ -87,6 +95,7 @@ class TestLoginRequired: # Ensure unauthorized was not called setup_app.login_manager.unauthorized.assert_not_called() + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_options_request_bypasses_authentication(self, setup_app: Flask): """Test that OPTIONS requests are exempt from authentication.""" @@ -103,6 +112,7 @@ class TestLoginRequired: # Ensure unauthorized was not called setup_app.login_manager.unauthorized.assert_not_called() + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_flask_2_compatibility(self, setup_app: Flask): """Test Flask 2.x compatibility with ensure_sync.""" @@ -120,6 +130,7 @@ class TestLoginRequired: assert result == "Synced content" setup_app.ensure_sync.assert_called_once() + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_flask_1_compatibility(self, setup_app: Flask): """Test Flask 1.x compatibility without ensure_sync.""" diff --git a/api/tests/unit_tests/libs/test_token.py b/api/tests/unit_tests/libs/test_token.py new file mode 100644 index 0000000000..22790fa4a6 --- /dev/null +++ b/api/tests/unit_tests/libs/test_token.py @@ -0,0 +1,23 @@ +from constants import COOKIE_NAME_ACCESS_TOKEN +from libs.token import extract_access_token + + +class MockRequest: + def __init__(self, headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]): + self.headers: dict[str, str] = headers + self.cookies: dict[str, str] = cookies + self.args: dict[str, str] = args + + +def test_extract_access_token(): + def _mock_request(headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]): + return MockRequest(headers, cookies, args) + + test_cases = [ + (_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123"), + (_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123"), + (_mock_request({}, {}, {}), None), + (_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None), + ] + for request, expected in test_cases: + assert extract_access_token(request) == expected # pyright: ignore[reportArgumentType] diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index e3cfc8e6a8..2185606a6d 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -2,16 +2,17 @@ import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' -import { removeAccessToken } from '@/app/components/share/utils' import { useWebAppStore } from '@/context/web-app-context' import { useGetUserCanAccessApp } from '@/service/access-control' import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share' +import { webAppLogout } from '@/service/webapp-auth' import { usePathname, useRouter, useSearchParams } from 'next/navigation' import React, { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { const { t } = useTranslation() + const shareCode = useWebAppStore(s => s.shareCode) const updateAppInfo = useWebAppStore(s => s.updateAppInfo) const updateAppParams = useWebAppStore(s => s.updateAppParams) const updateWebAppMeta = useWebAppStore(s => s.updateWebAppMeta) @@ -41,11 +42,11 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { return `/webapp-signin?${params.toString()}` }, [searchParams, pathname]) - const backToHome = useCallback(() => { - removeAccessToken() + const backToHome = useCallback(async () => { + await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router]) + }, [getSigninUrl, router, webAppLogout, shareCode]) if (appInfoError) { return
diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index 4fe9efe4dd..c26ea7e045 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -1,15 +1,16 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import { useEffect } from 'react' +import { useEffect, useState } from 'react' import { useCallback } from 'react' import { useWebAppStore } from '@/context/web-app-context' import { useRouter, useSearchParams } from 'next/navigation' import AppUnavailable from '@/app/components/base/app-unavailable' -import { checkOrSetAccessToken, removeAccessToken, setAccessToken } from '@/app/components/share/utils' import { useTranslation } from 'react-i18next' +import { AccessMode } from '@/models/access-control' +import { webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' import { fetchAccessToken } from '@/service/share' import Loading from '@/app/components/base/loading' -import { AccessMode } from '@/models/access-control' +import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' const Splash: FC = ({ children }) => { const { t } = useTranslation() @@ -18,9 +19,9 @@ const Splash: FC = ({ children }) => { const searchParams = useSearchParams() const router = useRouter() const redirectUrl = searchParams.get('redirect_url') - const tokenFromUrl = searchParams.get('web_sso_token') const message = searchParams.get('message') const code = searchParams.get('code') + const tokenFromUrl = searchParams.get('web_sso_token') const getSigninUrl = useCallback(() => { const params = new URLSearchParams(searchParams) params.delete('message') @@ -28,35 +29,66 @@ const Splash: FC = ({ children }) => { return `/webapp-signin?${params.toString()}` }, [searchParams]) - const backToHome = useCallback(() => { - removeAccessToken() + const backToHome = useCallback(async () => { + await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router]) + }, [getSigninUrl, router, webAppLogout, shareCode]) + const needCheckIsLogin = webAppAccessMode !== AccessMode.PUBLIC + const [isLoading, setIsLoading] = useState(true) useEffect(() => { + if (message) { + setIsLoading(false) + return + } + + if(tokenFromUrl) + setWebAppAccessToken(tokenFromUrl) + + const redirectOrFinish = () => { + if (redirectUrl) + router.replace(decodeURIComponent(redirectUrl)) + else + setIsLoading(false) + } + + const proceedToAuth = () => { + setIsLoading(false) + } + (async () => { - if (message) - return - if (shareCode && tokenFromUrl && redirectUrl) { - localStorage.setItem('webapp_access_token', tokenFromUrl) - const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: tokenFromUrl }) - await setAccessToken(shareCode, tokenResp.access_token) - router.replace(decodeURIComponent(redirectUrl)) - return + const { userLoggedIn, appLoggedIn } = await webAppLoginStatus(needCheckIsLogin, shareCode!) + + if (userLoggedIn && appLoggedIn) { + redirectOrFinish() } - if (shareCode && redirectUrl && localStorage.getItem('webapp_access_token')) { - const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: localStorage.getItem('webapp_access_token') }) - await setAccessToken(shareCode, tokenResp.access_token) - router.replace(decodeURIComponent(redirectUrl)) - return + else if (!userLoggedIn && !appLoggedIn) { + proceedToAuth() } - if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) { - await checkOrSetAccessToken(shareCode) - router.replace(decodeURIComponent(redirectUrl)) + else if (!userLoggedIn && appLoggedIn) { + redirectOrFinish() + } + else if (userLoggedIn && !appLoggedIn) { + try { + const { access_token } = await fetchAccessToken({ appCode: shareCode! }) + setWebAppPassport(shareCode!, access_token) + redirectOrFinish() + } + catch (error) { + await webAppLogout(shareCode!) + proceedToAuth() + } } })() - }, [shareCode, redirectUrl, router, tokenFromUrl, message, webAppAccessMode]) + }, [ + shareCode, + redirectUrl, + router, + message, + webAppAccessMode, + needCheckIsLogin, + tokenFromUrl]) if (message) { return
@@ -64,12 +96,8 @@ const Splash: FC = ({ children }) => { {code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')}
} - if (tokenFromUrl) { - return
- -
- } - if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) { + + if (isLoading) { return
diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index 3fc32fec71..4a1326fedf 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -10,7 +10,7 @@ import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' import I18NContext from '@/context/i18n' -import { setAccessToken } from '@/app/components/share/utils' +import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' import { fetchAccessToken } from '@/service/share' export default function CheckCode() { @@ -62,9 +62,9 @@ export default function CheckCode() { setIsLoading(true) const ret = await webAppEmailLoginWithCode({ email, code, token }) if (ret.result === 'success') { - localStorage.setItem('webapp_access_token', ret.data.access_token) - const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: ret.data.access_token }) - await setAccessToken(appCode, tokenResp.access_token) + setWebAppAccessToken(ret.data.access_token) + const { access_token } = await fetchAccessToken({ appCode: appCode! }) + setWebAppPassport(appCode!, access_token) router.replace(decodeURIComponent(redirectUrl)) } } diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index 2b6bd73df0..ce220b103e 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -11,15 +11,13 @@ import { webAppLogin } from '@/service/common' import Input from '@/app/components/base/input' import I18NContext from '@/context/i18n' import { noop } from 'lodash-es' -import { setAccessToken } from '@/app/components/share/utils' import { fetchAccessToken } from '@/service/share' +import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' type MailAndPasswordAuthProps = { isEmailSetup: boolean } -const passwordRegex = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ - export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAuthProps) { const { t } = useTranslation() const { locale } = useContext(I18NContext) @@ -43,8 +41,8 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut return appCode }, [redirectUrl]) + const appCode = getAppCodeFromRedirectUrl() const handleEmailPasswordLogin = async () => { - const appCode = getAppCodeFromRedirectUrl() if (!email) { Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) return @@ -60,13 +58,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut Toast.notify({ type: 'error', message: t('login.error.passwordEmpty') }) return } - if (!passwordRegex.test(password)) { - Toast.notify({ - type: 'error', - message: t('login.error.passwordInvalid'), - }) - return - } + if (!redirectUrl || !appCode) { Toast.notify({ type: 'error', @@ -88,9 +80,10 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut body: loginData, }) if (res.result === 'success') { - localStorage.setItem('webapp_access_token', res.data.access_token) - const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: res.data.access_token }) - await setAccessToken(appCode, tokenResp.access_token) + setWebAppAccessToken(res.data.access_token) + + const { access_token } = await fetchAccessToken({ appCode: appCode! }) + setWebAppPassport(appCode!, access_token) router.replace(decodeURIComponent(redirectUrl)) } else { @@ -141,9 +134,9 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
setPassword(e.target.value)} + id="password" onKeyDown={(e) => { if (e.key === 'Enter') handleEmailPasswordLogin() diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx index 1c6209b902..2ffa19c0c9 100644 --- a/web/app/(shareLayout)/webapp-signin/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/page.tsx @@ -3,13 +3,13 @@ import { useRouter, useSearchParams } from 'next/navigation' import type { FC } from 'react' import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { removeAccessToken } from '@/app/components/share/utils' import { useGlobalPublicStore } from '@/context/global-public-context' import AppUnavailable from '@/app/components/base/app-unavailable' import NormalForm from './normalForm' import { AccessMode } from '@/models/access-control' import ExternalMemberSsoAuth from './components/external-member-sso-auth' import { useWebAppStore } from '@/context/web-app-context' +import { webAppLogout } from '@/service/webapp-auth' const WebSSOForm: FC = () => { const { t } = useTranslation() @@ -26,11 +26,12 @@ const WebSSOForm: FC = () => { return `/webapp-signin?${params.toString()}` }, [redirectUrl]) - const backToHome = useCallback(() => { - removeAccessToken() + const shareCode = useWebAppStore(s => s.shareCode) + const backToHome = useCallback(async () => { + await webAppLogout(shareCode!) const url = getSigninUrl() router.replace(url) - }, [getSigninUrl, router]) + }, [getSigninUrl, router, webAppLogout, shareCode]) if (!redirectUrl) { return
diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index bd00f27ac5..d04cd18557 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -9,7 +9,6 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import { checkEmailExisted, - logout, resetEmail, sendVerifyCode, verifyEmail, @@ -17,6 +16,7 @@ import { import { noop } from 'lodash-es' import { asyncRunSafe } from '@/utils' import type { ResponseError } from '@/service/fetch' +import { useLogout } from '@/service/use-common' type Props = { show: boolean @@ -167,15 +167,12 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { setStep(STEP.verifyNew) } + const { mutateAsync: logout } = useLogout() const handleLogout = async () => { - await logout({ - url: '/logout', - params: {}, - }) + await logout() localStorage.removeItem('setup_status') - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Tokens are now stored in cookies and cleared by backend router.push('/signin') } diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index ea897e639f..d8943b7879 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -7,11 +7,11 @@ import { } from '@remixicon/react' import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react' import Avatar from '@/app/components/base/avatar' -import { logout } from '@/service/common' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import PremiumBadge from '@/app/components/base/premium-badge' +import { useLogout } from '@/service/use-common' export type IAppSelector = { isMobile: boolean @@ -23,15 +23,12 @@ export default function AppSelector() { const { userProfile } = useAppContext() const { isEducationAccount } = useProviderContext() + const { mutateAsync: logout } = useLogout() const handleLogout = async () => { - await logout({ - url: '/logout', - params: {}, - }) + await logout() localStorage.removeItem('setup_status') - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Tokens are now stored in cookies and cleared by backend router.push('/signin') } diff --git a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx index 2cd30bc3f2..64a378d2fe 100644 --- a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx @@ -8,7 +8,7 @@ import Button from '@/app/components/base/button' import CustomDialog from '@/app/components/base/dialog' import Textarea from '@/app/components/base/textarea' import Toast from '@/app/components/base/toast' -import { logout } from '@/service/common' +import { useLogout } from '@/service/use-common' type DeleteAccountProps = { onCancel: () => void @@ -22,14 +22,11 @@ export default function FeedBack(props: DeleteAccountProps) { const [userFeedback, setUserFeedback] = useState('') const { isPending, mutateAsync: sendFeedback } = useDeleteAccountFeedback() + const { mutateAsync: logout } = useLogout() const handleSuccess = useCallback(async () => { try { - await logout({ - url: '/logout', - params: {}, - }) - localStorage.removeItem('refresh_token') - localStorage.removeItem('console_token') + await logout() + // Tokens are now stored in cookies and cleared by backend router.push('/signin') Toast.notify({ type: 'info', message: t('common.account.deleteSuccessTip') }) } diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx index 078d23114a..2ab676d6b6 100644 --- a/web/app/account/oauth/authorize/layout.tsx +++ b/web/app/account/oauth/authorize/layout.tsx @@ -5,17 +5,22 @@ import cn from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import { AppContextProvider } from '@/context/app-context' -import { useMemo } from 'react' +import { useIsLogin } from '@/service/use-common' +import Loading from '@/app/components/base/loading' export default function SignInLayout({ children }: any) { const { systemFeatures } = useGlobalPublicStore() useDocumentTitle('') - const isLoggedIn = useMemo(() => { - try { - return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token')) - } - catch { return false } - }, []) + const { isLoading, data: loginData } = useIsLogin() + const isLoggedIn = loginData?.logged_in + + if(isLoading) { + return ( +
+ +
+ ) + } return <>
diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 6ad63996ae..4aa5fa0b8e 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -1,6 +1,6 @@ 'use client' -import React, { useEffect, useMemo, useRef } from 'react' +import React, { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' import { useRouter, useSearchParams } from 'next/navigation' import Button from '@/app/components/base/button' @@ -18,6 +18,7 @@ import { RiTranslate2, } from '@remixicon/react' import dayjs from 'dayjs' +import { useIsLogin } from '@/service/use-common' export const OAUTH_AUTHORIZE_PENDING_KEY = 'oauth_authorize_pending' export const REDIRECT_URL_KEY = 'oauth_redirect_url' @@ -74,17 +75,13 @@ export default function OAuthAuthorize() { const client_id = decodeURIComponent(searchParams.get('client_id') || '') const redirect_uri = decodeURIComponent(searchParams.get('redirect_uri') || '') const { userProfile } = useAppContext() - const { data: authAppInfo, isLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) + const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp() const hasNotifiedRef = useRef(false) - const isLoggedIn = useMemo(() => { - try { - return Boolean(localStorage.getItem('console_token') && localStorage.getItem('refresh_token')) - } - catch { return false } - }, []) - + const { isLoading: isIsLoginLoading, data: loginData } = useIsLogin() + const isLoggedIn = loginData?.logged_in + const isLoading = isOAuthLoading || isIsLoginLoading const onLoginSwitchClick = () => { try { const returnUrl = buildReturnUrl('/account/oauth/authorize', `?client_id=${encodeURIComponent(client_id)}&redirect_uri=${encodeURIComponent(redirect_uri)}`) diff --git a/web/app/components/app/app-access-control/access-control-dialog.tsx b/web/app/components/app/app-access-control/access-control-dialog.tsx index 479eedc9cf..ee3fa9650b 100644 --- a/web/app/components/app/app-access-control/access-control-dialog.tsx +++ b/web/app/components/app/app-access-control/access-control-dialog.tsx @@ -22,7 +22,7 @@ const AccessControlDialog = ({ }, [onClose]) return ( - null}> + null}> -
+
diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index 0fad6cc740..e9519aeedf 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -52,7 +52,7 @@ export default function AddMemberOrGroupDialog() { {open && } - +
diff --git a/web/app/components/base/chat/chat-with-history/index.tsx b/web/app/components/base/chat/chat-with-history/index.tsx index 464e30a821..6953be4b3c 100644 --- a/web/app/components/base/chat/chat-with-history/index.tsx +++ b/web/app/components/base/chat/chat-with-history/index.tsx @@ -4,7 +4,6 @@ import { useEffect, useState, } from 'react' -import { useAsyncEffect } from 'ahooks' import { useThemeContext } from '../embedded-chatbot/theme/theme-context' import { ChatWithHistoryContext, @@ -18,8 +17,6 @@ import ChatWrapper from './chat-wrapper' import type { InstalledApp } from '@/models/explore' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' -import { checkOrSetAccessToken } from '@/app/components/share/utils' -import AppUnavailable from '@/app/components/base/app-unavailable' import cn from '@/utils/classnames' import useDocumentTitle from '@/hooks/use-document-title' @@ -201,36 +198,6 @@ const ChatWithHistoryWrapWithCheckToken: FC = ({ installedAppInfo, className, }) => { - const [initialized, setInitialized] = useState(false) - const [appUnavailable, setAppUnavailable] = useState(false) - const [isUnknownReason, setIsUnknownReason] = useState(false) - - useAsyncEffect(async () => { - if (!initialized) { - if (!installedAppInfo) { - try { - await checkOrSetAccessToken() - } - catch (e: any) { - if (e.status === 404) { - setAppUnavailable(true) - } - else { - setIsUnknownReason(true) - setAppUnavailable(true) - } - } - } - setInitialized(true) - } - }, []) - - if (!initialized) - return null - - if (appUnavailable) - return - return ( { - await logout({ - url: '/logout', - params: {}, - }) + await logout() localStorage.removeItem('setup_status') - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Tokens are now stored in cookies and cleared by backend // To avoid use other account's education notice info localStorage.removeItem('education-reverify-prev-expire-at') diff --git a/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts index ad757f36a7..51782f3cbf 100644 --- a/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/rag-pipeline/hooks/use-nodes-sync-draft.ts @@ -77,7 +77,7 @@ export const useNodesSyncDraft = () => { if (postParams) { navigator.sendBeacon( - `${API_PREFIX}${postParams.url}?_token=${localStorage.getItem('console_token')}`, + `${API_PREFIX}${postParams.url}`, JSON.stringify(postParams.params), ) } diff --git a/web/app/components/share/text-generation/menu-dropdown.tsx b/web/app/components/share/text-generation/menu-dropdown.tsx index 373e3b8699..e3b12b3d84 100644 --- a/web/app/components/share/text-generation/menu-dropdown.tsx +++ b/web/app/components/share/text-generation/menu-dropdown.tsx @@ -20,6 +20,7 @@ import type { SiteInfo } from '@/models/share' import cn from '@/utils/classnames' import { AccessMode } from '@/models/access-control' import { useWebAppStore } from '@/context/web-app-context' +import { webAppLogout } from '@/service/webapp-auth' type Props = { data?: SiteInfo @@ -49,11 +50,11 @@ const MenuDropdown: FC = ({ setOpen(!openRef.current) }, [setOpen]) - const handleLogout = useCallback(() => { - localStorage.removeItem('token') - localStorage.removeItem('webapp_access_token') + const shareCode = useWebAppStore(s => s.shareCode) + const handleLogout = useCallback(async () => { + await webAppLogout(shareCode!) router.replace(`/webapp-signin?redirect_url=${pathname}`) - }, [router, pathname]) + }, [router, pathname, webAppLogout, shareCode]) const [show, setShow] = useState(false) diff --git a/web/app/components/share/utils.ts b/web/app/components/share/utils.ts index 3f5303dfcc..491433322d 100644 --- a/web/app/components/share/utils.ts +++ b/web/app/components/share/utils.ts @@ -1,7 +1,3 @@ -import { CONVERSATION_ID_INFO } from '../base/chat/constants' -import { fetchAccessToken } from '@/service/share' -import { getProcessedSystemVariablesFromUrlParams } from '../base/chat/utils' - export const isTokenV1 = (token: Record) => { return !token.version } @@ -9,55 +5,3 @@ export const isTokenV1 = (token: Record) => { export const getInitialTokenV2 = (): Record => ({ version: 2, }) - -export const checkOrSetAccessToken = async (appCode?: string | null) => { - const sharedToken = appCode || globalThis.location.pathname.split('/').slice(-1)[0] - const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id - const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2()) - let accessTokenJson = getInitialTokenV2() - try { - accessTokenJson = JSON.parse(accessToken) - if (isTokenV1(accessTokenJson)) - accessTokenJson = getInitialTokenV2() - } - catch { - - } - - if (!accessTokenJson[sharedToken]?.[userId || 'DEFAULT']) { - const webAppAccessToken = localStorage.getItem('webapp_access_token') - const res = await fetchAccessToken({ appCode: sharedToken, userId, webAppAccessToken }) - accessTokenJson[sharedToken] = { - ...accessTokenJson[sharedToken], - [userId || 'DEFAULT']: res.access_token, - } - localStorage.setItem('token', JSON.stringify(accessTokenJson)) - localStorage.removeItem(CONVERSATION_ID_INFO) - } -} - -export const setAccessToken = (sharedToken: string, token: string, user_id?: string) => { - const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2()) - let accessTokenJson = getInitialTokenV2() - try { - accessTokenJson = JSON.parse(accessToken) - if (isTokenV1(accessTokenJson)) - accessTokenJson = getInitialTokenV2() - } - catch { - - } - - localStorage.removeItem(CONVERSATION_ID_INFO) - - accessTokenJson[sharedToken] = { - ...accessTokenJson[sharedToken], - [user_id || 'DEFAULT']: token, - } - localStorage.setItem('token', JSON.stringify(accessTokenJson)) -} - -export const removeAccessToken = () => { - localStorage.removeItem('token') - localStorage.removeItem('webapp_access_token') -} diff --git a/web/app/components/swr-initializer.tsx b/web/app/components/swr-initializer.tsx index fd9432fdd8..1ab1567659 100644 --- a/web/app/components/swr-initializer.tsx +++ b/web/app/components/swr-initializer.tsx @@ -19,10 +19,7 @@ const SwrInitializer = ({ }: SwrInitializerProps) => { const router = useRouter() const searchParams = useSearchParams() - const consoleToken = decodeURIComponent(searchParams.get('access_token') || '') - const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '') - const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') - const refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token') + // Tokens are now stored in cookies, no need to check localStorage const pathname = usePathname() const [init, setInit] = useState(false) @@ -57,21 +54,12 @@ const SwrInitializer = ({ router.replace('/install') return } - if (!((consoleToken && refreshToken) || (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage))) { - router.replace('/signin') - return - } - if (searchParams.has('access_token') || searchParams.has('refresh_token')) { - if (consoleToken) - localStorage.setItem('console_token', consoleToken) - if (refreshToken) - localStorage.setItem('refresh_token', refreshToken) - const redirectUrl = resolvePostLoginRedirect(searchParams) - if (redirectUrl) - location.replace(redirectUrl) - else - router.replace(pathname) - } + + const redirectUrl = resolvePostLoginRedirect(searchParams) + if (redirectUrl) + location.replace(redirectUrl) + else + router.replace(pathname) setInit(true) } @@ -79,7 +67,7 @@ const SwrInitializer = ({ router.replace('/signin') } })() - }, [isSetupFinished, router, pathname, searchParams, consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage]) + }, [isSetupFinished, router, pathname, searchParams]) return init ? ( diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 5705deb0c0..d33bfcc8b8 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -97,7 +97,7 @@ export const useNodesSyncDraft = () => { if (postParams) { navigator.sendBeacon( - `${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`, + `${API_PREFIX}/apps/${params.appId}/workflows/draft`, JSON.stringify(postParams.params), ) } diff --git a/web/app/education-apply/user-info.tsx b/web/app/education-apply/user-info.tsx index e1d60a5e94..96ff1aaae6 100644 --- a/web/app/education-apply/user-info.tsx +++ b/web/app/education-apply/user-info.tsx @@ -2,24 +2,21 @@ import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' import Button from '@/app/components/base/button' import { useAppContext } from '@/context/app-context' -import { logout } from '@/service/common' import Avatar from '@/app/components/base/avatar' import { Triangle } from '@/app/components/base/icons/src/public/education' +import { useLogout } from '@/service/use-common' const UserInfo = () => { const router = useRouter() const { t } = useTranslation() const { userProfile } = useAppContext() + const { mutateAsync: logout } = useLogout() const handleLogout = async () => { - await logout({ - url: '/logout', - params: {}, - }) + await logout() localStorage.removeItem('setup_status') - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Tokens are now stored in cookies and cleared by backend router.push('/signin') } diff --git a/web/app/install/installForm.tsx b/web/app/install/installForm.tsx index 65d1998fcc..0a534b72fe 100644 --- a/web/app/install/installForm.tsx +++ b/web/app/install/installForm.tsx @@ -72,8 +72,6 @@ const InstallForm = () => { // Store tokens and redirect to apps if login successful if (loginRes.result === 'success') { - localStorage.setItem('console_token', loginRes.data.access_token) - localStorage.setItem('refresh_token', loginRes.data.refresh_token) router.replace('/apps') } else { diff --git a/web/app/signin/check-code/page.tsx b/web/app/signin/check-code/page.tsx index 8f12d807db..da6bd426af 100644 --- a/web/app/signin/check-code/page.tsx +++ b/web/app/signin/check-code/page.tsx @@ -42,8 +42,6 @@ export default function CheckCode() { setIsLoading(true) const ret = await emailLoginWithCode({ email, code, token }) if (ret.result === 'success') { - localStorage.setItem('console_token', ret.data.access_token) - localStorage.setItem('refresh_token', ret.data.refresh_token) if (invite_token) { router.replace(`/signin/invite-settings?${searchParams.toString()}`) } diff --git a/web/app/signin/components/mail-and-password-auth.tsx b/web/app/signin/components/mail-and-password-auth.tsx index 5214b73ee0..2740a82782 100644 --- a/web/app/signin/components/mail-and-password-auth.tsx +++ b/web/app/signin/components/mail-and-password-auth.tsx @@ -30,6 +30,7 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis const [password, setPassword] = useState('') const [isLoading, setIsLoading] = useState(false) + const handleEmailPasswordLogin = async () => { if (!email) { Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) @@ -66,8 +67,6 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis router.replace(`/signin/invite-settings?${searchParams.toString()}`) } else { - localStorage.setItem('console_token', res.data.access_token) - localStorage.setItem('refresh_token', res.data.refresh_token) const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') } diff --git a/web/app/signin/invite-settings/page.tsx b/web/app/signin/invite-settings/page.tsx index cec51a70ef..cbd37f51f6 100644 --- a/web/app/signin/invite-settings/page.tsx +++ b/web/app/signin/invite-settings/page.tsx @@ -58,8 +58,7 @@ export default function InviteSettingsPage() { }, }) if (res.result === 'success') { - localStorage.setItem('console_token', res.data.access_token) - localStorage.setItem('refresh_token', res.data.refresh_token) + // Tokens are now stored in cookies by the backend await setLocaleOnClient(language, false) const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') diff --git a/web/app/signin/normal-form.tsx b/web/app/signin/normal-form.tsx index a5a30a0cdd..920a992b4f 100644 --- a/web/app/signin/normal-form.tsx +++ b/web/app/signin/normal-form.tsx @@ -16,16 +16,18 @@ import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' import { resolvePostLoginRedirect } from './utils/post-login-redirect' import Split from './split' +import { useIsLogin } from '@/service/use-common' const NormalForm = () => { const { t } = useTranslation() const router = useRouter() const searchParams = useSearchParams() - const consoleToken = decodeURIComponent(searchParams.get('access_token') || '') - const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '') + const { isLoading: isCheckLoading, data: loginData } = useIsLogin() + const isLoggedIn = loginData?.logged_in const message = decodeURIComponent(searchParams.get('message') || '') const invite_token = decodeURIComponent(searchParams.get('invite_token') || '') - const [isLoading, setIsLoading] = useState(true) + const [isInitCheckLoading, setInitCheckLoading] = useState(true) + const isLoading = isCheckLoading || loginData?.logged_in || isInitCheckLoading const { systemFeatures } = useGlobalPublicStore() const [authType, updateAuthType] = useState<'code' | 'password'>('password') const [showORLine, setShowORLine] = useState(false) @@ -36,9 +38,7 @@ const NormalForm = () => { const init = useCallback(async () => { try { - if (consoleToken && refreshToken) { - localStorage.setItem('console_token', consoleToken) - localStorage.setItem('refresh_token', refreshToken) + if (isLoggedIn) { const redirectUrl = resolvePostLoginRedirect(searchParams) router.replace(redirectUrl || '/apps') return @@ -67,12 +67,12 @@ const NormalForm = () => { console.error(error) setAllMethodsAreDisabled(true) } - finally { setIsLoading(false) } - }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink, systemFeatures]) + finally { setInitCheckLoading(false) } + }, [isLoggedIn, message, router, invite_token, isInviteLink, systemFeatures]) useEffect(() => { init() }, [init]) - if (isLoading || consoleToken) { + if (isLoading) { return
{ new_password: password, password_confirm: confirmPassword, }) - const { result, data } = res as MailRegisterResponse + const { result } = res as MailRegisterResponse if (result === 'success') { Toast.notify({ type: 'success', message: t('common.api.actionSuccess'), }) - localStorage.setItem('console_token', data.access_token) - localStorage.setItem('refresh_token', data.refresh_token) router.replace('/apps') } } diff --git a/web/config/index.ts b/web/config/index.ts index f818a1c0af..0e876b800e 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -144,6 +144,17 @@ export const getMaxToken = (modelId: string) => { export const LOCALE_COOKIE_NAME = 'locale' +export const CSRF_COOKIE_NAME = () => { + const isSecure = API_PREFIX.startsWith('https://') + return isSecure ? '__Host-csrf_token' : 'csrf_token' +} +export const CSRF_HEADER_NAME = 'X-CSRF-Token' +export const ACCESS_TOKEN_LOCAL_STORAGE_NAME = 'access_token' +export const PASSPORT_LOCAL_STORAGE_NAME = (appCode: string) => `passport-${appCode}` +export const PASSPORT_HEADER_NAME = 'X-App-Passport' + +export const WEB_APP_SHARE_CODE_HEADER_NAME = 'X-App-Code' + export const DEFAULT_VALUE_MAX_LEN = 48 export const DEFAULT_PARAGRAPH_VALUE_MAX_LEN = 1000 diff --git a/web/context/web-app-context.tsx b/web/context/web-app-context.tsx index 0fe1b56b0a..48de01f2df 100644 --- a/web/context/web-app-context.tsx +++ b/web/context/web-app-context.tsx @@ -2,14 +2,12 @@ import type { ChatConfig } from '@/app/components/base/chat/types' import Loading from '@/app/components/base/loading' -import { checkOrSetAccessToken } from '@/app/components/share/utils' import { AccessMode } from '@/models/access-control' import type { AppData, AppMeta } from '@/models/share' import { useGetWebAppAccessModeByCode } from '@/service/use-share' import { usePathname, useSearchParams } from 'next/navigation' import type { FC, PropsWithChildren } from 'react' import { useEffect } from 'react' -import { useState } from 'react' import { create } from 'zustand' import { useGlobalPublicStore } from './global-public-context' @@ -71,24 +69,13 @@ const WebAppStoreProvider: FC = ({ children }) => { }, [shareCode, updateShareCode]) const { isFetching, data: accessModeResult } = useGetWebAppAccessModeByCode(shareCode) - const [isFetchingAccessToken, setIsFetchingAccessToken] = useState(true) useEffect(() => { - if (accessModeResult?.accessMode) { + if (accessModeResult?.accessMode) updateWebAppAccessMode(accessModeResult.accessMode) - if (accessModeResult.accessMode === AccessMode.PUBLIC) { - setIsFetchingAccessToken(true) - checkOrSetAccessToken(shareCode).finally(() => { - setIsFetchingAccessToken(false) - }) - } - else { - setIsFetchingAccessToken(false) - } - } }, [accessModeResult, updateWebAppAccessMode, shareCode]) - if (isGlobalPending || isFetching || isFetchingAccessToken) { + if (isGlobalPending || isFetching) { return
diff --git a/web/models/app.ts b/web/models/app.ts index 630dba9c19..26e6cba85b 100644 --- a/web/models/app.ts +++ b/web/models/app.ts @@ -2,63 +2,6 @@ import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikCo import type { App, AppMode, AppTemplate, SiteConfig } from '@/types/app' import type { Dependency } from '@/app/components/plugins/types' -/* export type App = { - id: string - name: string - description: string - mode: AppMode - enable_site: boolean - enable_api: boolean - api_rpm: number - api_rph: number - is_demo: boolean - model_config: AppModelConfig - providers: Array<{ provider: string; token_is_set: boolean }> - site: SiteConfig - created_at: string -} - -export type AppModelConfig = { - provider: string - model_id: string - configs: { - prompt_template: string - prompt_variables: Array - completion_params: CompletionParam - } -} - -export type PromptVariable = { - key: string - name: string - description: string - type: string | number - default: string - options: string[] -} - -export type CompletionParam = { - max_tokens: number - temperature: number - top_p: number - echo: boolean - stop: string[] - presence_penalty: number - frequency_penalty: number -} - -export type SiteConfig = { - access_token: string - title: string - author: string - support_email: string - default_language: string - customize_domain: string - theme: string - customize_token_strategy: 'must' | 'allow' | 'not_allow' - prompt_public: boolean -} */ - export enum DSLImportMode { YAML_CONTENT = 'yaml-content', YAML_URL = 'yaml-url', diff --git a/web/service/base.ts b/web/service/base.ts index 1cb99e38d3..6e54e228e1 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -1,4 +1,4 @@ -import { API_PREFIX, IS_CE_EDITION, PUBLIC_API_PREFIX } from '@/config' +import { API_PREFIX, CSRF_COOKIE_NAME, CSRF_HEADER_NAME, IS_CE_EDITION, PASSPORT_HEADER_NAME, PUBLIC_API_PREFIX, WEB_APP_SHARE_CODE_HEADER_NAME } from '@/config' import { refreshAccessTokenOrRelogin } from './refresh-token' import Toast from '@/app/components/base/toast' import { basePath } from '@/utils/var' @@ -21,15 +21,16 @@ import type { WorkflowFinishedResponse, WorkflowStartedResponse, } from '@/types/workflow' -import { removeAccessToken } from '@/app/components/share/utils' import type { FetchOptionType, ResponseError } from './fetch' -import { ContentType, base, getAccessToken, getBaseOptions } from './fetch' +import { ContentType, base, getBaseOptions } from './fetch' import { asyncRunSafe } from '@/utils' import type { DataSourceNodeCompletedResponse, DataSourceNodeErrorResponse, DataSourceNodeProcessingResponse, } from '@/types/pipeline' +import Cookies from 'js-cookie' +import { getWebAppPassport } from './webapp-auth' const TIME_OUT = 100000 export type IOnDataMoreInfo = { @@ -122,14 +123,19 @@ function unicodeToChar(text: string) { }) } +const WBB_APP_LOGIN_PATH = '/webapp-signin' function requiredWebSSOLogin(message?: string, code?: number) { const params = new URLSearchParams() + // prevent redirect loop + if(globalThis.location.pathname === WBB_APP_LOGIN_PATH) + return + params.append('redirect_url', encodeURIComponent(`${globalThis.location.pathname}${globalThis.location.search}`)) if (message) params.append('message', message) if (code) params.append('code', String(code)) - globalThis.location.href = `${globalThis.location.origin}${basePath}/webapp-signin?${params.toString()}` + globalThis.location.href = `${globalThis.location.origin}${basePath}/${WBB_APP_LOGIN_PATH}?${params.toString()}` } export function format(text: string) { @@ -338,12 +344,14 @@ type UploadResponse = { export const upload = async (options: UploadOptions, isPublicAPI?: boolean, url?: string, searchParams?: string): Promise => { const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX - const token = await getAccessToken(isPublicAPI) + const shareCode = globalThis.location.pathname.split('/').slice(-1)[0] const defaultOptions = { method: 'POST', url: (url ? `${urlPrefix}${url}` : `${urlPrefix}/files/upload`) + (searchParams || ''), headers: { - Authorization: `Bearer ${token}`, + [CSRF_HEADER_NAME]: Cookies.get(CSRF_COOKIE_NAME()) || '', + [PASSPORT_HEADER_NAME]: getWebAppPassport(shareCode), + [WEB_APP_SHARE_CODE_HEADER_NAME]: shareCode, }, } const mergedOptions = { @@ -413,14 +421,17 @@ export const ssePost = async ( } = otherOptions const abortController = new AbortController() - const token = localStorage.getItem('console_token') + // No need to get token from localStorage, cookies will be sent automatically const baseOptions = getBaseOptions() + const shareCode = globalThis.location.pathname.split('/').slice(-1)[0] const options = Object.assign({}, baseOptions, { method: 'POST', signal: abortController.signal, headers: new Headers({ - Authorization: `Bearer ${token}`, + [CSRF_HEADER_NAME]: Cookies.get(CSRF_COOKIE_NAME()) || '', + [WEB_APP_SHARE_CODE_HEADER_NAME]: shareCode, + [PASSPORT_HEADER_NAME]: getWebAppPassport(shareCode), }), } as RequestInit, fetchOptions) @@ -439,9 +450,6 @@ export const ssePost = async ( if (body) options.body = JSON.stringify(body) - const accessToken = await getAccessToken(isPublicAPI) - ; (options.headers as Headers).set('Authorization', `Bearer ${accessToken}`) - globalThis.fetch(urlWithPrefix, options as RequestInit) .then((res) => { if (!/^[23]\d{2}$/.test(String(res.status))) { @@ -452,15 +460,11 @@ export const ssePost = async ( if (data.code === 'web_app_access_denied') requiredWebSSOLogin(data.message, 403) - if (data.code === 'web_sso_auth_required') { - removeAccessToken() + if (data.code === 'web_sso_auth_required') requiredWebSSOLogin() - } - if (data.code === 'unauthorized') { - removeAccessToken() + if (data.code === 'unauthorized') requiredWebSSOLogin() - } } }) } @@ -551,13 +555,11 @@ export const request = async(url: string, options = {}, otherOptions?: IOther return Promise.reject(err) } if (code === 'web_sso_auth_required') { - removeAccessToken() requiredWebSSOLogin() return Promise.reject(err) } if (code === 'unauthorized_and_force_logout') { - localStorage.removeItem('console_token') - localStorage.removeItem('refresh_token') + // Cookies will be cleared by the backend globalThis.location.reload() return Promise.reject(err) } @@ -566,7 +568,6 @@ export const request = async(url: string, options = {}, otherOptions?: IOther silent, } = otherOptionsForBaseFetch if (isPublicAPI && code === 'unauthorized') { - removeAccessToken() requiredWebSSOLogin() return Promise.reject(err) } diff --git a/web/service/common.ts b/web/service/common.ts index d70315f5c6..8f2adc329e 100644 --- a/web/service/common.ts +++ b/web/service/common.ts @@ -40,7 +40,7 @@ import type { SystemFeatures } from '@/types/feature' type LoginSuccess = { result: 'success' - data: { access_token: string; refresh_token: string } + data: { access_token: string } } type LoginFail = { result: 'fail' @@ -56,10 +56,6 @@ export const webAppLogin: Fetcher } -export const fetchNewToken: Fetcher }> = ({ body }) => { - return post('/refresh-token', { body }) as Promise -} - export const setup: Fetcher }> = ({ body }) => { return post('/setup', { body }) } @@ -84,10 +80,6 @@ export const updateUserProfile: Fetcher(url, { body }) } -export const logout: Fetcher }> = ({ url, params }) => { - return get(url, params) -} - export const fetchLangGeniusVersion: Fetcher }> = ({ url, params }) => { return get(url, { params }) } diff --git a/web/service/fetch.ts b/web/service/fetch.ts index 4e76843ba2..541b1246d4 100644 --- a/web/service/fetch.ts +++ b/web/service/fetch.ts @@ -2,9 +2,9 @@ import type { AfterResponseHook, BeforeErrorHook, BeforeRequestHook, Hooks } fro import ky from 'ky' import type { IOtherOptions } from './base' import Toast from '@/app/components/base/toast' -import { API_PREFIX, APP_VERSION, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config' -import { getInitialTokenV2, isTokenV1 } from '@/app/components/share/utils' -import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/chat/utils' +import { API_PREFIX, APP_VERSION, CSRF_COOKIE_NAME, CSRF_HEADER_NAME, MARKETPLACE_API_PREFIX, PASSPORT_HEADER_NAME, PUBLIC_API_PREFIX, WEB_APP_SHARE_CODE_HEADER_NAME } from '@/config' +import Cookies from 'js-cookie' +import { getWebAppAccessToken, getWebAppPassport } from './webapp-auth' const TIME_OUT = 100000 @@ -69,35 +69,15 @@ const beforeErrorToast = (otherOptions: IOtherOptions): BeforeErrorHook => { } } -export async function getAccessToken(isPublicAPI?: boolean) { - if (isPublicAPI) { - const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] - const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id - const accessToken = localStorage.getItem('token') || JSON.stringify({ version: 2 }) - let accessTokenJson: Record = { version: 2 } - try { - accessTokenJson = JSON.parse(accessToken) - if (isTokenV1(accessTokenJson)) - accessTokenJson = getInitialTokenV2() - } - catch { - - } - return accessTokenJson[sharedToken]?.[userId || 'DEFAULT'] - } - else { - return localStorage.getItem('console_token') || '' - } -} - -const beforeRequestPublicAuthorization: BeforeRequestHook = async (request) => { - const token = await getAccessToken(true) - request.headers.set('Authorization', `Bearer ${token}`) -} - -const beforeRequestAuthorization: BeforeRequestHook = async (request) => { - const accessToken = await getAccessToken() - request.headers.set('Authorization', `Bearer ${accessToken}`) +const beforeRequestPublicWithCode = (request: Request) => { + request.headers.set('Authorization', `Bearer ${getWebAppAccessToken()}`) + const shareCode = globalThis.location.pathname.split('/').filter(Boolean).pop() || '' + // some pages does not end with share code, so we need to check it + // TODO: maybe find a better way to access app code? + if (shareCode === 'webapp-signin' || shareCode === 'check-code') + return + request.headers.set(WEB_APP_SHARE_CODE_HEADER_NAME, shareCode) + request.headers.set(PASSPORT_HEADER_NAME, getWebAppPassport(shareCode)) } const baseHooks: Hooks = { @@ -148,6 +128,8 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: } const fetchPathname = base + (url.startsWith('/') ? url : `/${url}`) + if (!isMarketplaceAPI) + (headers as any).set(CSRF_HEADER_NAME, Cookies.get(CSRF_COOKIE_NAME()) || '') if (deleteContentType) (headers as any).delete('Content-Type') @@ -165,8 +147,7 @@ async function base(url: string, options: FetchOptionType = {}, otherOptions: ], beforeRequest: [ ...baseHooks.beforeRequest || [], - isPublicAPI && beforeRequestPublicAuthorization, - !isPublicAPI && !isMarketplaceAPI && beforeRequestAuthorization, + isPublicAPI && beforeRequestPublicWithCode, ].filter((h): h is BeforeRequestHook => Boolean(h)), afterResponse: [ ...baseHooks.afterResponse || [], diff --git a/web/service/refresh-token.ts b/web/service/refresh-token.ts index 7eff08b52f..3f63f628a1 100644 --- a/web/service/refresh-token.ts +++ b/web/service/refresh-token.ts @@ -39,7 +39,6 @@ async function getNewAccessToken(timeout: number): Promise { globalThis.localStorage.setItem(LOCAL_STORAGE_KEY, '1') globalThis.localStorage.setItem('last_refresh_time', new Date().getTime().toString()) globalThis.addEventListener('beforeunload', releaseRefreshLock) - const refresh_token = globalThis.localStorage.getItem('refresh_token') // Do not use baseFetch to refresh tokens. // If a 401 response occurs and baseFetch itself attempts to refresh the token, @@ -48,10 +47,11 @@ async function getNewAccessToken(timeout: number): Promise { // that does not call baseFetch and uses a single retry mechanism. const [error, ret] = await fetchWithRetry(globalThis.fetch(`${API_PREFIX}/refresh-token`, { method: 'POST', + credentials: 'include', // Important: include cookies in the request headers: { 'Content-Type': 'application/json;utf-8', }, - body: JSON.stringify({ refresh_token }), + // No body needed - refresh token is in cookie })) if (error) { return Promise.reject(error) @@ -59,10 +59,6 @@ async function getNewAccessToken(timeout: number): Promise { else { if (ret.status === 401) return Promise.reject(ret) - - const { data } = await ret.json() - globalThis.localStorage.setItem('console_token', data.access_token) - globalThis.localStorage.setItem('refresh_token', data.refresh_token) } } } diff --git a/web/service/share.ts b/web/service/share.ts index ab8e0deb4a..ce03f508d1 100644 --- a/web/service/share.ts +++ b/web/service/share.ts @@ -34,6 +34,8 @@ import type { } from '@/models/share' import type { ChatConfig } from '@/app/components/base/chat/types' import type { AccessMode } from '@/models/access-control' +import { WEB_APP_SHARE_CODE_HEADER_NAME } from '@/config' +import { getWebAppAccessToken } from './webapp-auth' function getAction(action: 'get' | 'post' | 'del' | 'patch', isInstalledApp: boolean) { switch (action) { @@ -286,16 +288,14 @@ export const textToAudioStream = (url: string, isPublicAPI: boolean, header: { c return (getAction('post', !isPublicAPI))(url, { body, header }, { needAllResponseContent: true }) } -export const fetchAccessToken = async ({ appCode, userId, webAppAccessToken }: { appCode: string, userId?: string, webAppAccessToken?: string | null }) => { +export const fetchAccessToken = async ({ userId, appCode }: { userId?: string, appCode: string }) => { const headers = new Headers() - headers.append('X-App-Code', appCode) + headers.append(WEB_APP_SHARE_CODE_HEADER_NAME, appCode) + headers.append('Authorization', `Bearer ${getWebAppAccessToken()}`) const params = new URLSearchParams() - if (webAppAccessToken) - params.append('web_app_access_token', webAppAccessToken) - if (userId) - params.append('user_id', userId) + userId && params.append('user_id', userId) const url = `/passport?${params.toString()}` - return get(url, { headers }) as Promise<{ access_token: string }> + return get<{ access_token: string }>(url, { headers }) as Promise<{ access_token: string }> } export const getUserCanAccess = (appId: string, isInstalledApp: boolean) => { diff --git a/web/service/use-common.ts b/web/service/use-common.ts index 330ee674b0..3e01b721e8 100644 --- a/web/service/use-common.ts +++ b/web/service/use-common.ts @@ -50,7 +50,7 @@ export const useMailValidity = () => { }) } -export type MailRegisterResponse = { result: string, data: { access_token: string, refresh_token: string } } +export type MailRegisterResponse = { result: string, data: {} } export const useMailRegister = () => { return useMutation({ @@ -106,3 +106,23 @@ export const useSchemaTypeDefinitions = () => { queryFn: () => get('/spec/schema-definitions'), }) } + +type isLogin = { + logged_in: boolean +} + +export const useIsLogin = () => { + return useQuery({ + queryKey: [NAME_SPACE, 'is-login'], + staleTime: 0, + gcTime: 0, + queryFn: () => get('/login/status'), + }) +} + +export const useLogout = () => { + return useMutation({ + mutationKey: [NAME_SPACE, 'logout'], + mutationFn: () => post('/logout'), + }) +} diff --git a/web/service/use-share.ts b/web/service/use-share.ts index 267975fd38..a5e0a11100 100644 --- a/web/service/use-share.ts +++ b/web/service/use-share.ts @@ -8,6 +8,8 @@ export const useGetWebAppAccessModeByCode = (code: string | null) => { queryKey: [NAME_SPACE, 'appAccessMode', code], queryFn: () => getAppAccessModeByAppCode(code!), enabled: !!code, + staleTime: 0, // backend change the access mode may cause the logic error. Because /permission API is no cached. + gcTime: 0, }) } diff --git a/web/service/webapp-auth.ts b/web/service/webapp-auth.ts new file mode 100644 index 0000000000..a7ce7153bf --- /dev/null +++ b/web/service/webapp-auth.ts @@ -0,0 +1,53 @@ +import { ACCESS_TOKEN_LOCAL_STORAGE_NAME, PASSPORT_LOCAL_STORAGE_NAME } from '@/config' +import { getPublic, postPublic } from './base' + +export function setWebAppAccessToken(token: string) { + localStorage.setItem(ACCESS_TOKEN_LOCAL_STORAGE_NAME, token) +} + +export function setWebAppPassport(shareCode: string, token: string) { + localStorage.setItem(PASSPORT_LOCAL_STORAGE_NAME(shareCode), token) +} + +export function getWebAppAccessToken() { + return localStorage.getItem(ACCESS_TOKEN_LOCAL_STORAGE_NAME) || '' +} + +export function getWebAppPassport(shareCode: string) { + return localStorage.getItem(PASSPORT_LOCAL_STORAGE_NAME(shareCode)) || '' +} + +export function clearWebAppAccessToken() { + localStorage.removeItem(ACCESS_TOKEN_LOCAL_STORAGE_NAME) +} + +export function clearWebAppPassport(shareCode: string) { + localStorage.removeItem(PASSPORT_LOCAL_STORAGE_NAME(shareCode)) +} + +type isWebAppLogin = { + logged_in: boolean + app_logged_in: boolean +} + +export async function webAppLoginStatus(enabled: boolean, shareCode: string) { + if (!enabled) { + return { + userLoggedIn: true, + appLoggedIn: true, + } + } + + // check remotely, the access token could be in cookie (enterprise SSO redirected with https) + const { logged_in, app_logged_in } = await getPublic(`/login/status?app_code=${shareCode}`) + return { + userLoggedIn: logged_in, + appLoggedIn: app_logged_in, + } +} + +export async function webAppLogout(shareCode: string) { + clearWebAppAccessToken() + clearWebAppPassport(shareCode) + await postPublic('/logout') +}