mirror of https://github.com/langgenius/dify.git
fix: adjust interceptor for web apis
This commit is contained in:
parent
2616f89d46
commit
60ce8f6053
|
|
@ -27,11 +27,13 @@ class PassportResource(Resource):
|
|||
if app_code is None:
|
||||
raise Unauthorized("X-App-Code header is missing.")
|
||||
|
||||
# logic for exchange token for enterprise logined web user
|
||||
enterprise_user_id = decode_enterprise_webapp_user_id(enterprise_login_token)
|
||||
if enterprise_user_id:
|
||||
# exchange token for enterprise logined web user
|
||||
enterprise_user_decoded = decode_enterprise_webapp_user_id(enterprise_login_token)
|
||||
if enterprise_user_decoded:
|
||||
# a web user has already logged in, exchange a token for this app without redirecting to the login page
|
||||
return exchange_token_for_existing_web_user(app_code=app_code, user_id=enterprise_user_id)
|
||||
return exchange_token_for_existing_web_user(
|
||||
app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
|
||||
)
|
||||
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||
|
|
@ -112,15 +114,16 @@ def decode_enterprise_webapp_user_id(auth_header: str | None):
|
|||
decoded = PassportService().verify(tk)
|
||||
source = decoded.get("token_source")
|
||||
if not source or source != "enterprise_login":
|
||||
return None
|
||||
user_id: str | None = decoded.get("user_id")
|
||||
return user_id
|
||||
raise Unauthorized("Invalid token source. Expected 'enterprise_login'.")
|
||||
return decoded
|
||||
|
||||
|
||||
def exchange_token_for_existing_web_user(app_code: str, user_id: str):
|
||||
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
|
||||
"""
|
||||
Exchange a token for an existing web user session.
|
||||
"""
|
||||
user_id = enterprise_user_decoded.get("user_id")
|
||||
end_user_id = enterprise_user_decoded.get("end_user_id")
|
||||
|
||||
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
|
||||
if not site:
|
||||
|
|
@ -129,7 +132,7 @@ def exchange_token_for_existing_web_user(app_code: str, user_id: str):
|
|||
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||
if not app_model or app_model.status != "normal" or not app_model.enable_site:
|
||||
raise NotFound()
|
||||
end_user = db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if not end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=app_model.tenant_id,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire
|
|||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.model import App, EndUser, Site
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
|
|
@ -44,7 +44,9 @@ def decode_jwt_token():
|
|||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
decoded = PassportService().verify(tk)
|
||||
app_code = decoded.get("app_code")
|
||||
decoded_app_code = decoded.get("app_code")
|
||||
if not decoded_app_code or decoded_app_code != app_code:
|
||||
raise Unauthorized("Invalid app code in token.")
|
||||
app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first()
|
||||
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||
if not app_model:
|
||||
|
|
@ -59,13 +61,17 @@ def decode_jwt_token():
|
|||
|
||||
# for enterprise webapp auth
|
||||
app_web_auth_enabled = False
|
||||
webapp_settings = None
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
|
||||
)
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||
if not webapp_settings:
|
||||
raise NotFound("Web app settings not found.")
|
||||
app_web_auth_enabled = webapp_settings.access_mode != "public"
|
||||
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(
|
||||
decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings
|
||||
)
|
||||
|
||||
return app_model, end_user
|
||||
except Unauthorized as e:
|
||||
|
|
@ -95,15 +101,27 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au
|
|||
raise Unauthorized("webapp token expired.")
|
||||
|
||||
|
||||
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
|
||||
def _validate_user_accessibility(
|
||||
decoded,
|
||||
app_code,
|
||||
app_web_auth_enabled: bool,
|
||||
system_webapp_auth_enabled: bool,
|
||||
webapp_settings: WebAppSettings | None,
|
||||
):
|
||||
if system_webapp_auth_enabled and app_web_auth_enabled:
|
||||
# Check if the user is allowed to access the web app
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
|
||||
raise WebAppAuthAccessDeniedError()
|
||||
if not webapp_settings:
|
||||
raise WebAppAuthRequiredError("Web app settings not found.")
|
||||
|
||||
access_modes_require_permission_check = ["private", "private_all"]
|
||||
|
||||
if webapp_settings.access_mode in access_modes_require_permission_check:
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
|
||||
raise WebAppAuthAccessDeniedError()
|
||||
|
||||
|
||||
class WebApiResource(Resource):
|
||||
|
|
|
|||
|
|
@ -129,8 +129,6 @@ class WebAppAuthService:
|
|||
payload = {
|
||||
"iss": site.id,
|
||||
"sub": "Web API Passport",
|
||||
"app_id": site.app_id,
|
||||
"app_code": site.code,
|
||||
"user_id": account.id,
|
||||
"end_user_id": end_user_id,
|
||||
"token_source": "enterprise_login",
|
||||
|
|
|
|||
Loading…
Reference in New Issue