From 3c8d03d24f0789097381df3d6241430e8a60c80c Mon Sep 17 00:00:00 2001 From: "yunlu.wen" Date: Wed, 17 Jun 2026 11:00:56 +0800 Subject: [PATCH] handle enduser in decorator --- api/controllers/common/wraps.py | 34 ++++++++++++++++++++- api/controllers/openapi/app_run.py | 6 ++-- api/controllers/openapi/human_input_form.py | 6 ++-- api/controllers/openapi/workflow_events.py | 4 +-- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/api/controllers/common/wraps.py b/api/controllers/common/wraps.py index c481f6eca94..ce9a3740d44 100644 --- a/api/controllers/common/wraps.py +++ b/api/controllers/common/wraps.py @@ -1,9 +1,41 @@ +from __future__ import annotations + from collections.abc import Callable from functools import wraps +from typing import TYPE_CHECKING from core.rbac import RBACPermission, RBACResourceScope -__all__ = ["RBACPermission", "RBACResourceScope", "rbac_permission_required"] +if TYPE_CHECKING: + from controllers.openapi.auth.data import AuthData + +__all__ = ["RBACPermission", "RBACResourceScope", "openapi_rbac_permission_required", "rbac_permission_required"] + + +def openapi_rbac_permission_required[**P, R]( + resource_type: RBACResourceScope, + scene: RBACPermission, + *, + resource_required: bool = True, +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """RBAC guard for OpenAPI endpoints that may be called by either an Account or an EndUser. + """ + inner = rbac_permission_required(resource_type, scene, resource_required=resource_required) + + def decorator(view: Callable[P, R]) -> Callable[P, R]: + guarded = inner(view) + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs) -> R: + auth_data: AuthData | None = kwargs.get("auth_data") + if auth_data is not None and auth_data.caller_kind == "end_user": + # we can skip rbac for enduser for now. + return view(*args, **kwargs) + return guarded(*args, **kwargs) + + return decorated + + return decorator def rbac_permission_required[**P, R]( diff --git a/api/controllers/openapi/app_run.py b/api/controllers/openapi/app_run.py index 7f214480110..1cec04844fb 100644 --- a/api/controllers/openapi/app_run.py +++ b/api/controllers/openapi/app_run.py @@ -19,7 +19,7 @@ from werkzeug.exceptions import ( import services from controllers.common.fields import EventStreamResponse -from controllers.common.wraps import RBACPermission, RBACResourceScope, rbac_permission_required +from controllers.common.wraps import RBACPermission, RBACResourceScope, openapi_rbac_permission_required from controllers.openapi import openapi_ns from controllers.openapi._audit import emit_app_run from controllers.openapi._contract import accepts, returns @@ -138,7 +138,7 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = { @openapi_ns.route("/apps//run") class AppRunApi(Resource): @auth_router.guard(scope=Scope.APPS_RUN) - @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) + @openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) @openapi_ns.response(200, "Run result (SSE stream)", openapi_ns.models[EventStreamResponse.__name__]) @accepts(body=AppRunRequest) def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest): @@ -170,7 +170,7 @@ class AppRunApi(Resource): @openapi_ns.route("/apps//tasks//stop") class AppRunTaskStopApi(Resource): @auth_router.guard(scope=Scope.APPS_RUN) - @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) + @openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) @returns(200, TaskStopResponse, description="Task stopped") def post(self, app_id: str, task_id: str, *, auth_data: AuthData): app_model, caller, caller_kind = auth_data.require_app_context() diff --git a/api/controllers/openapi/human_input_form.py b/api/controllers/openapi/human_input_form.py index 51a0b49de20..6148c2ab8de 100644 --- a/api/controllers/openapi/human_input_form.py +++ b/api/controllers/openapi/human_input_form.py @@ -16,7 +16,7 @@ from werkzeug.exceptions import BadRequest, NotFound from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values from controllers.common.schema import register_schema_models -from controllers.common.wraps import RBACPermission, RBACResourceScope, rbac_permission_required +from controllers.common.wraps import RBACPermission, RBACResourceScope, openapi_rbac_permission_required from controllers.openapi import openapi_ns from controllers.openapi._contract import accepts, returns from controllers.openapi._models import FormSubmitResponse, HumanInputFormDefinitionResponse @@ -60,7 +60,7 @@ def _ensure_form_is_allowed_for_openapi(form) -> None: class OpenApiWorkflowHumanInputFormApi(Resource): @openapi_ns.response(200, "Form definition", openapi_ns.models[HumanInputFormDefinitionResponse.__name__]) @auth_router.guard(scope=Scope.APPS_RUN) - @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) + @openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) def get(self, app_id: str, form_token: str, *, auth_data: AuthData): app_model, caller, caller_kind = auth_data.require_app_context() service = HumanInputService(db.engine) @@ -74,7 +74,7 @@ class OpenApiWorkflowHumanInputFormApi(Resource): return _jsonify_form_definition(form) @auth_router.guard(scope=Scope.APPS_RUN) - @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) + @openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) @returns(200, FormSubmitResponse, description="Form submitted") @accepts(body=HumanInputFormSubmitPayload) def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload): diff --git a/api/controllers/openapi/workflow_events.py b/api/controllers/openapi/workflow_events.py index 7a4c657bd61..77f60224a48 100644 --- a/api/controllers/openapi/workflow_events.py +++ b/api/controllers/openapi/workflow_events.py @@ -19,7 +19,7 @@ from werkzeug.exceptions import NotFound, UnprocessableEntity from controllers.common.fields import EventStreamResponse from controllers.common.schema import query_params_from_model -from controllers.common.wraps import RBACPermission, RBACResourceScope, rbac_permission_required +from controllers.common.wraps import RBACPermission, RBACResourceScope, openapi_rbac_permission_required from controllers.openapi import openapi_ns from controllers.openapi.auth.composition import auth_router from controllers.openapi.auth.data import AuthData @@ -48,7 +48,7 @@ class OpenApiWorkflowEventsApi(Resource): @openapi_ns.doc(params=query_params_from_model(WorkflowEventsQuery)) @openapi_ns.response(200, "SSE event stream", openapi_ns.models[EventStreamResponse.__name__]) @auth_router.guard(scope=Scope.APPS_RUN) - @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) + @openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN) def get(self, app_id: str, task_id: str, *, auth_data: AuthData): app_model, caller, caller_kind = auth_data.require_app_context() app_mode = AppMode.value_of(app_model.mode)