diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 0a7ba552ee..ed67c24bd3 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -514,6 +514,17 @@ def with_current_tenant_id[T, **P, R]( def with_current_user[T, **P, R]( view: Callable[Concatenate[T, Account, P], R], ) -> Callable[Concatenate[T, P], R]: + """Inject the current authenticated Account into the handler as the first argument after self. + + Usage:: + + class MyResource(Resource): + @login_required + @with_current_user + def get(self, current_user: Account): + ... + """ + @wraps(view) def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R: current_user, _ = current_account_with_tenant() @@ -522,6 +533,30 @@ def with_current_user[T, **P, R]( return decorated +def with_current_user_id[T, **P, R]( + view: Callable[Concatenate[T, str, P], R], +) -> Callable[Concatenate[T, P], R]: + """Inject the current authenticated user's ID (as a string) into the handler. + + Use this when the handler only needs the user ID and not the full Account object. + + Usage:: + + class MyResource(Resource): + @login_required + @with_current_user_id + def get(self, current_user_id: str): + ... + """ + + @wraps(view) + def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R: + current_user, _ = current_account_with_tenant() + return view(self, str(current_user.id), *args, **kwargs) + + return decorated + + def model_validate[T, M: BaseModel, **P, R]( model: type[M], ) -> Callable[ diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 714b114752..fb2ef55fe8 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -22,6 +22,7 @@ from controllers.console.wraps import ( setup_required, with_current_tenant_id, with_current_user, + with_current_user_id, ) from models import Account from models.account import AccountStatus, TenantAccountRole @@ -124,6 +125,19 @@ class TestCurrentContextInjection: with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")): assert Handler().get() is current_user + def test_with_current_user_id_injects_user_id_string(self): + current_user = make_account("user-42") + + class Handler: + @with_current_user_id + def get(self, current_user_id: str): + return current_user_id + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")): + result = Handler().get() + assert result == "user-42" + assert isinstance(result, str) + def test_stacked_current_context_injectors_preserve_argument_order(self): current_user = make_account() @@ -136,6 +150,18 @@ class TestCurrentContextInjection: with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")): assert Handler().get() == ("tenant-123", current_user) + def test_stacked_user_id_and_tenant_id_injectors(self): + current_user = make_account("user-99") + + class Handler: + @with_current_user_id + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user_id: str): + return current_user_id, current_tenant_id + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-456")): + assert Handler().get() == ("user-99", "tenant-456") + class TestModelValidationInjection: """Test request model validation decorator."""