fix: adjust interceptor for web apis

This commit is contained in:
GareArc 2025-05-28 17:56:31 +08:00
parent 2616f89d46
commit 60ce8f6053
No known key found for this signature in database
3 changed files with 39 additions and 20 deletions

View File

@ -27,11 +27,13 @@ class PassportResource(Resource):
if app_code is None: if app_code is None:
raise Unauthorized("X-App-Code header is missing.") raise Unauthorized("X-App-Code header is missing.")
# logic for exchange token for enterprise logined web user # exchange token for enterprise logined web user
enterprise_user_id = decode_enterprise_webapp_user_id(enterprise_login_token) enterprise_user_decoded = decode_enterprise_webapp_user_id(enterprise_login_token)
if enterprise_user_id: if enterprise_user_decoded:
# a web user has already logged in, exchange a token for this app without redirecting to the login page # a web user has already logged in, exchange a token for this app without redirecting to the login page
return exchange_token_for_existing_web_user(app_code=app_code, user_id=enterprise_user_id) return exchange_token_for_existing_web_user(
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
)
if system_features.webapp_auth.enabled: if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
@ -112,15 +114,16 @@ def decode_enterprise_webapp_user_id(auth_header: str | None):
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
source = decoded.get("token_source") source = decoded.get("token_source")
if not source or source != "enterprise_login": if not source or source != "enterprise_login":
return None raise Unauthorized("Invalid token source. Expected 'enterprise_login'.")
user_id: str | None = decoded.get("user_id") return decoded
return user_id
def exchange_token_for_existing_web_user(app_code: str, user_id: str): def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
""" """
Exchange a token for an existing web user session. Exchange a token for an existing web user session.
""" """
user_id = enterprise_user_decoded.get("user_id")
end_user_id = enterprise_user_decoded.get("end_user_id")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
if not site: if not site:
@ -129,7 +132,7 @@ def exchange_token_for_existing_web_user(app_code: str, user_id: str):
app_model = db.session.query(App).filter(App.id == site.app_id).first() app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
end_user = db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user: if not end_user:
end_user = EndUser( end_user = EndUser(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,

View File

@ -8,7 +8,7 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire
from extensions.ext_database import db from extensions.ext_database import db
from libs.passport import PassportService from libs.passport import PassportService
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -44,7 +44,9 @@ def decode_jwt_token():
if auth_scheme != "bearer": if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
app_code = decoded.get("app_code") decoded_app_code = decoded.get("app_code")
if not decoded_app_code or decoded_app_code != app_code:
raise Unauthorized("Invalid app code in token.")
app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first()
site = db.session.query(Site).filter(Site.code == app_code).first() site = db.session.query(Site).filter(Site.code == app_code).first()
if not app_model: if not app_model:
@ -59,13 +61,17 @@ def decode_jwt_token():
# for enterprise webapp auth # for enterprise webapp auth
app_web_auth_enabled = False app_web_auth_enabled = False
webapp_settings = None
if system_features.webapp_auth.enabled: if system_features.webapp_auth.enabled:
app_web_auth_enabled = ( webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" if not webapp_settings:
) raise NotFound("Web app settings not found.")
app_web_auth_enabled = webapp_settings.access_mode != "public"
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) _validate_user_accessibility(
decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings
)
return app_model, end_user return app_model, end_user
except Unauthorized as e: except Unauthorized as e:
@ -95,15 +101,27 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au
raise Unauthorized("webapp token expired.") raise Unauthorized("webapp token expired.")
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): def _validate_user_accessibility(
decoded,
app_code,
app_web_auth_enabled: bool,
system_webapp_auth_enabled: bool,
webapp_settings: WebAppSettings | None,
):
if system_webapp_auth_enabled and app_web_auth_enabled: if system_webapp_auth_enabled and app_web_auth_enabled:
# Check if the user is allowed to access the web app # Check if the user is allowed to access the web app
user_id = decoded.get("user_id") user_id = decoded.get("user_id")
if not user_id: if not user_id:
raise WebAppAuthRequiredError() raise WebAppAuthRequiredError()
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): if not webapp_settings:
raise WebAppAuthAccessDeniedError() raise WebAppAuthRequiredError("Web app settings not found.")
access_modes_require_permission_check = ["private", "private_all"]
if webapp_settings.access_mode in access_modes_require_permission_check:
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
raise WebAppAuthAccessDeniedError()
class WebApiResource(Resource): class WebApiResource(Resource):

View File

@ -129,8 +129,6 @@ class WebAppAuthService:
payload = { payload = {
"iss": site.id, "iss": site.id,
"sub": "Web API Passport", "sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": account.id, "user_id": account.id,
"end_user_id": end_user_id, "end_user_id": end_user_id,
"token_source": "enterprise_login", "token_source": "enterprise_login",