mirror of
https://github.com/langgenius/dify.git
synced 2026-06-23 04:11:09 +08:00
handle enduser in decorator
This commit is contained in:
parent
39bf04e7fe
commit
3c8d03d24f
@ -1,9 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.rbac import RBACPermission, RBACResourceScope
|
||||
|
||||
__all__ = ["RBACPermission", "RBACResourceScope", "rbac_permission_required"]
|
||||
if TYPE_CHECKING:
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
|
||||
__all__ = ["RBACPermission", "RBACResourceScope", "openapi_rbac_permission_required", "rbac_permission_required"]
|
||||
|
||||
|
||||
def openapi_rbac_permission_required[**P, R](
|
||||
resource_type: RBACResourceScope,
|
||||
scene: RBACPermission,
|
||||
*,
|
||||
resource_required: bool = True,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""RBAC guard for OpenAPI endpoints that may be called by either an Account or an EndUser.
|
||||
"""
|
||||
inner = rbac_permission_required(resource_type, scene, resource_required=resource_required)
|
||||
|
||||
def decorator(view: Callable[P, R]) -> Callable[P, R]:
|
||||
guarded = inner(view)
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
auth_data: AuthData | None = kwargs.get("auth_data")
|
||||
if auth_data is not None and auth_data.caller_kind == "end_user":
|
||||
# we can skip rbac for enduser for now.
|
||||
return view(*args, **kwargs)
|
||||
return guarded(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def rbac_permission_required[**P, R](
|
||||
|
||||
@ -19,7 +19,7 @@ from werkzeug.exceptions import (
|
||||
|
||||
import services
|
||||
from controllers.common.fields import EventStreamResponse
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope, rbac_permission_required
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope, openapi_rbac_permission_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
@ -138,7 +138,7 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@openapi_ns.response(200, "Run result (SSE stream)", openapi_ns.models[EventStreamResponse.__name__])
|
||||
@accepts(body=AppRunRequest)
|
||||
def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest):
|
||||
@ -170,7 +170,7 @@ class AppRunApi(Resource):
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@returns(200, TaskStopResponse, description="Task stopped")
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
|
||||
@ -16,7 +16,7 @@ from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope, rbac_permission_required
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope, openapi_rbac_permission_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import FormSubmitResponse, HumanInputFormDefinitionResponse
|
||||
@ -60,7 +60,7 @@ def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition", openapi_ns.models[HumanInputFormDefinitionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
service = HumanInputService(db.engine)
|
||||
@ -74,7 +74,7 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@returns(200, FormSubmitResponse, description="Form submitted")
|
||||
@accepts(body=HumanInputFormSubmitPayload)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload):
|
||||
|
||||
@ -19,7 +19,7 @@ from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import EventStreamResponse
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope, rbac_permission_required
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope, openapi_rbac_permission_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
@ -48,7 +48,7 @@ class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(WorkflowEventsQuery))
|
||||
@openapi_ns.response(200, "SSE event stream", openapi_ns.models[EventStreamResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@openapi_rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user