refactor: refactor rbac backend implement (#35900)

Co-authored-by: twwu <twwu@dify.ai>
This commit is contained in:
wangxiaolei 2026-05-08 00:04:54 +08:00 committed by GitHub
parent 3a525a609c
commit 9216d74c61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 265 additions and 156 deletions

View File

@ -1,28 +1,16 @@
"""Dify Console controllers that proxy the enterprise RBAC surface.
Each route here is a thin adapter: it validates the pydantic payload shown in
the screenshots (`Settings > Permissions`, `Settings > Access Rules`,
`App/Knowledge Base Access Config`, and the `Settings > Members` role
assignment dialog), pulls ``tenant_id`` / ``account_id`` from the current
Dify session and forwards to the inner RBAC client defined in
``services/enterprise/rbac_service.py``. The client then calls the
``/inner/api/rbac/*`` endpoints on dify-enterprise over HTTP using the
shared ``Enterprise-Api-Secret-Key`` header.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, ValidationError
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationError
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_account_with_tenant, login_required
from services.enterprise import rbac_service as svc
@ -75,6 +63,23 @@ def _dump(model: BaseModel) -> dict[str, Any]:
return model.model_dump(mode="json")
class _PaginationQuery(BaseModel):
model_config = ConfigDict(extra="ignore")
page_number: int | None = Field(default=None, ge=1, validation_alias=AliasChoices("page", "page_number"))
results_per_page: int | None = Field(
default=None, ge=1, le=100, validation_alias=AliasChoices("limit", "results_per_page")
)
reverse: bool | None = None
def to_inner_options(self) -> svc.ListOption:
return svc.ListOption.model_validate(self.model_dump())
def _pagination_options() -> svc.ListOption:
return _PaginationQuery.model_validate(request.args.to_dict(flat=True)).to_inner_options()
# ---------------------------------------------------------------------------
# Permission catalogs.
# ---------------------------------------------------------------------------
@ -83,9 +88,7 @@ def _dump(model: BaseModel) -> dict[str, Any]:
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
class RBACWorkspaceCatalogApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.workspace(tenant_id, account_id))
@ -94,9 +97,7 @@ class RBACWorkspaceCatalogApi(Resource):
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/app")
class RBACAppCatalogApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.app(tenant_id, account_id))
@ -105,9 +106,7 @@ class RBACAppCatalogApi(Resource):
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/dataset")
class RBACDatasetCatalogApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Catalog.dataset(tenant_id, account_id))
@ -122,14 +121,12 @@ class _RoleUpsertRequest(BaseModel):
"""Accepts the payload sent by the Create/Edit Role dialog."""
name: str
role_key: str
description: str = ""
permission_keys: list[str] = []
def to_mutation(self) -> svc.RoleMutation:
return svc.RoleMutation(
name=self.name,
role_key=self.role_key,
description=self.description,
permission_keys=list(self.permission_keys),
)
@ -138,18 +135,14 @@ class _RoleUpsertRequest(BaseModel):
@console_ns.route("/workspaces/current/rbac/roles")
class RBACRolesApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id, account_id = _current_ids()
options = svc.ListOption()
options = _pagination_options()
return _dump(svc.RBACService.Roles.list(tenant_id, account_id, options=options))
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def post(self):
tenant_id, account_id = _current_ids()
request = _payload(_RoleUpsertRequest)
@ -160,17 +153,13 @@ class RBACRolesApi(Resource):
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>")
class RBACRoleItemApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, role_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.Roles.get(tenant_id, account_id, str(role_id)))
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, role_id):
tenant_id, account_id = _current_ids()
request = _payload(_RoleUpsertRequest)
@ -178,9 +167,7 @@ class RBACRoleItemApi(Resource):
return _dump(role)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def delete(self, role_id):
tenant_id, account_id = _current_ids()
svc.RBACService.Roles.delete(tenant_id, account_id, str(role_id))
@ -208,29 +195,23 @@ class _AccessPolicyUpdateRequest(BaseModel):
@console_ns.route("/workspaces/current/rbac/access-policies")
class RBACAccessPoliciesApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id, account_id = _current_ids()
# `resource_type` is exposed as a query argument so the UI can show
# only app-scoped or only dataset-scoped permission sets.
from flask import request
resource_type = request.args.get("resource_type") or None
return _dump(
svc.RBACService.AccessPolicies.list(
tenant_id,
account_id,
resource_type=resource_type,
options=svc.ListOption(),
options=_pagination_options(),
)
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def post(self):
tenant_id, account_id = _current_ids()
request = _payload(_AccessPolicyCreateRequest)
@ -250,17 +231,13 @@ class RBACAccessPoliciesApi(Resource):
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>")
class RBACAccessPolicyItemApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.AccessPolicies.get(tenant_id, account_id, str(policy_id)))
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_AccessPolicyUpdateRequest)
@ -277,9 +254,7 @@ class RBACAccessPolicyItemApi(Resource):
return _dump(policy)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def delete(self, policy_id):
tenant_id, account_id = _current_ids()
svc.RBACService.AccessPolicies.delete(tenant_id, account_id, str(policy_id))
@ -289,9 +264,7 @@ class RBACAccessPolicyItemApi(Resource):
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>/copy")
class RBACAccessPolicyCopyApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def post(self, policy_id):
tenant_id, account_id = _current_ids()
policy = svc.RBACService.AccessPolicies.copy(tenant_id, account_id, str(policy_id))
@ -304,19 +277,34 @@ class RBACAccessPolicyCopyApi(Resource):
class _ReplaceRoleBindingsRequest(BaseModel):
role_keys: list[str] = []
role_ids: list[str] = []
class _ReplaceMemberBindingsRequest(BaseModel):
account_ids: list[str] = []
@console_ns.route("/workspaces/current/rbac/my-permissions")
class RBACMyPermissionsApi(Resource):
@enterprise_only
@login_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(
svc.RBACService.MyPermissions.get(
tenant_id,
account_id,
app_id=request.args.get("app_id") or None,
dataset_id=request.args.get("dataset_id") or None,
)
)
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policy")
class RBACAppMatrixApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, app_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.AppAccess.matrix(tenant_id, account_id, str(app_id)))
@ -325,9 +313,7 @@ class RBACAppMatrixApi(Resource):
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/role-bindings")
class RBACAppRoleBindingsApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
@ -335,9 +321,7 @@ class RBACAppRoleBindingsApi(Resource):
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceRoleBindingsRequest)
@ -347,7 +331,7 @@ class RBACAppRoleBindingsApi(Resource):
account_id,
str(app_id),
str(policy_id),
svc.ReplaceRoleBindings(role_keys=list(request.role_keys)),
svc.ReplaceRoleBindings(role_ids=list(request.role_ids)),
)
)
@ -355,9 +339,7 @@ class RBACAppRoleBindingsApi(Resource):
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/member-bindings")
class RBACAppMemberBindingsApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
@ -365,9 +347,7 @@ class RBACAppMemberBindingsApi(Resource):
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, app_id, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceMemberBindingsRequest)
@ -390,9 +370,7 @@ class RBACAppMemberBindingsApi(Resource):
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policy")
class RBACDatasetMatrixApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.DatasetAccess.matrix(tenant_id, account_id, str(dataset_id)))
@ -401,9 +379,7 @@ class RBACDatasetMatrixApi(Resource):
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/role-bindings")
class RBACDatasetRoleBindingsApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
@ -413,9 +389,7 @@ class RBACDatasetRoleBindingsApi(Resource):
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceRoleBindingsRequest)
@ -425,7 +399,7 @@ class RBACDatasetRoleBindingsApi(Resource):
account_id,
str(dataset_id),
str(policy_id),
svc.ReplaceRoleBindings(role_keys=list(request.role_keys)),
svc.ReplaceRoleBindings(role_ids=list(request.role_ids)),
)
)
@ -435,9 +409,7 @@ class RBACDatasetRoleBindingsApi(Resource):
)
class RBACDatasetMemberBindingsApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
@ -447,9 +419,7 @@ class RBACDatasetMemberBindingsApi(Resource):
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceMemberBindingsRequest)
@ -472,20 +442,17 @@ class RBACDatasetMemberBindingsApi(Resource):
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
class RBACWorkspaceAppMatrixApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.WorkspaceAccess.app_matrix(tenant_id, account_id))
options = _pagination_options()
return _dump(svc.RBACService.WorkspaceAccess.app_matrix(tenant_id, account_id, options=options))
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/role-bindings")
class RBACWorkspaceAppRoleBindingsApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
@ -493,9 +460,7 @@ class RBACWorkspaceAppRoleBindingsApi(Resource):
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceRoleBindingsRequest)
@ -504,7 +469,7 @@ class RBACWorkspaceAppRoleBindingsApi(Resource):
tenant_id,
account_id,
str(policy_id),
svc.ReplaceRoleBindings(role_keys=list(request.role_keys)),
svc.ReplaceRoleBindings(role_ids=list(request.role_ids)),
)
)
@ -512,9 +477,7 @@ class RBACWorkspaceAppRoleBindingsApi(Resource):
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/member-bindings")
class RBACWorkspaceAppMemberBindingsApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
@ -522,9 +485,7 @@ class RBACWorkspaceAppMemberBindingsApi(Resource):
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceMemberBindingsRequest)
@ -541,20 +502,17 @@ class RBACWorkspaceAppMemberBindingsApi(Resource):
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policy")
class RBACWorkspaceDatasetMatrixApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.WorkspaceAccess.dataset_matrix(tenant_id, account_id))
options = _pagination_options()
return _dump(svc.RBACService.WorkspaceAccess.dataset_matrix(tenant_id, account_id, options=options))
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/role-bindings")
class RBACWorkspaceDatasetRoleBindingsApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
@ -562,9 +520,7 @@ class RBACWorkspaceDatasetRoleBindingsApi(Resource):
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceRoleBindingsRequest)
@ -573,7 +529,7 @@ class RBACWorkspaceDatasetRoleBindingsApi(Resource):
tenant_id,
account_id,
str(policy_id),
svc.ReplaceRoleBindings(role_keys=list(request.role_keys)),
svc.ReplaceRoleBindings(role_ids=list(request.role_ids)),
)
)
@ -581,9 +537,7 @@ class RBACWorkspaceDatasetRoleBindingsApi(Resource):
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/member-bindings")
class RBACWorkspaceDatasetMemberBindingsApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, policy_id):
tenant_id, account_id = _current_ids()
return _dump(
@ -591,9 +545,7 @@ class RBACWorkspaceDatasetMemberBindingsApi(Resource):
)
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, policy_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceMemberBindingsRequest)
@ -613,23 +565,19 @@ class RBACWorkspaceDatasetMemberBindingsApi(Resource):
class _ReplaceMemberRolesRequest(BaseModel):
role_keys: list[str] = []
role_ids: list[str] = []
@console_ns.route("/workspaces/current/rbac/members/<uuid:member_id>/rbac-roles")
class RBACMemberRolesApi(Resource):
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def get(self, member_id):
tenant_id, account_id = _current_ids()
return _dump(svc.RBACService.MemberRoles.get(tenant_id, account_id, str(member_id)))
@enterprise_only
@setup_required
@login_required
@account_initialization_required
def put(self, member_id):
tenant_id, account_id = _current_ids()
request = _payload(_ReplaceMemberRolesRequest)
@ -638,6 +586,6 @@ class RBACMemberRolesApi(Resource):
tenant_id,
account_id,
str(member_id),
role_keys=list(request.role_keys),
role_ids=list(request.role_ids),
)
)

View File

@ -80,6 +80,7 @@ app_detail_fields = {
"updated_at": TimestampField,
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
"permission_keys": fields.List(fields.String),
}
prompt_config_fields = {
@ -117,6 +118,7 @@ app_partial_fields = {
"create_user_name": fields.String,
"author_name": fields.String,
"has_draft_trigger": fields.Boolean,
"permission_keys": fields.List(fields.String),
}
@ -197,6 +199,7 @@ app_detail_fields_with_site = {
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)),
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
"permission_keys": fields.List(fields.String),
"site": fields.Nested(site_fields),
}

View File

@ -11,6 +11,7 @@ dataset_fields = {
"indexing_technique": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
"permission_keys": fields.List(fields.String),
}
reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String}
@ -107,6 +108,7 @@ dataset_detail_fields = {
"total_available_documents": fields.Integer,
"enable_api": fields.Boolean,
"is_multimodal": fields.Boolean,
"permission_keys": fields.List(fields.String),
}
file_info_fields = {

View File

@ -61,7 +61,6 @@ class RBACRole(_RBACModel):
tenant_id: str | None = None
type: str
category: str = ""
role_key: str
name: str
description: str = ""
is_builtin: bool = False
@ -88,7 +87,7 @@ class AccessPolicyRoleBinding(_RBACModel):
access_policy_id: str
resource_type: str
resource_id: str = ""
role_key: str
role_id: str
created_at: int = 0
@ -104,7 +103,7 @@ class AccessPolicyMemberBinding(_RBACModel):
class AccessMatrixItem(_RBACModel):
policy: AccessPolicy | None = None
role_keys: list[str] = Field(default_factory=list)
role_ids: list[str] = Field(default_factory=list)
account_ids: list[str] = Field(default_factory=list)
@ -120,6 +119,7 @@ class DatasetAccessMatrix(_RBACModel):
class WorkspaceAccessMatrix(_RBACModel):
items: list[AccessMatrixItem] = Field(default_factory=list)
pagination: Pagination | None = None
class RoleBindingsResponse(_RBACModel):
@ -135,6 +135,26 @@ class MemberRolesResponse(_RBACModel):
roles: list[RBACRole] = Field(default_factory=list)
class ResourcePermissionKeys(_RBACModel):
resource_id: str
permission_keys: list[str] = Field(default_factory=list)
class WorkspacePermissionSnapshot(_RBACModel):
permission_keys: list[str] = Field(default_factory=list)
class ResourcePermissionSnapshot(_RBACModel):
default_permission_keys: list[str] = Field(default_factory=list)
overrides: list[ResourcePermissionKeys] = Field(default_factory=list)
class MyPermissionsResponse(_RBACModel):
workspace: WorkspacePermissionSnapshot = Field(default_factory=WorkspacePermissionSnapshot)
app: ResourcePermissionSnapshot = Field(default_factory=ResourcePermissionSnapshot)
dataset: ResourcePermissionSnapshot = Field(default_factory=ResourcePermissionSnapshot)
# ---------- Mutation request models ----------
@ -146,7 +166,6 @@ class RoleMutation(_RBACModel):
"""
name: str
role_key: str
description: str = ""
permission_keys: list[str] = Field(default_factory=list)
type: RBACRoleType = RBACRoleType.WORKSPACE
@ -166,7 +185,8 @@ class AccessPolicyUpdate(_RBACModel):
class ReplaceRoleBindings(_RBACModel):
role_keys: list[str] = Field(default_factory=list)
role_ids: list[str] = Field(default_factory=list)
class ReplaceMemberBindings(_RBACModel):
@ -594,22 +614,34 @@ class RBACService:
# ------------------------------------------------------------------
class WorkspaceAccess:
@staticmethod
def app_matrix(tenant_id: str, account_id: str | None = None) -> WorkspaceAccessMatrix:
def app_matrix(
tenant_id: str,
account_id: str | None = None,
*,
options: ListOption | None = None,
) -> WorkspaceAccessMatrix:
data = _inner_call(
"GET",
f"{_INNER_PREFIX}/workspace/apps/access-policy",
tenant_id=tenant_id,
account_id=account_id,
params=(options or ListOption()).to_params() or None,
)
return WorkspaceAccessMatrix.model_validate(data or {})
@staticmethod
def dataset_matrix(tenant_id: str, account_id: str | None = None) -> WorkspaceAccessMatrix:
def dataset_matrix(
tenant_id: str,
account_id: str | None = None,
*,
options: ListOption | None = None,
) -> WorkspaceAccessMatrix:
data = _inner_call(
"GET",
f"{_INNER_PREFIX}/workspace/datasets/access-policy",
tenant_id=tenant_id,
account_id=account_id,
params=(options or ListOption()).to_params() or None,
)
return WorkspaceAccessMatrix.model_validate(data or {})
@ -761,7 +793,7 @@ class RBACService:
tenant_id: str,
account_id: str | None,
member_account_id: str,
role_keys: list[str],
role_ids: list[str],
) -> MemberRolesResponse:
data = _inner_call(
"PUT",
@ -769,6 +801,32 @@ class RBACService:
tenant_id=tenant_id,
account_id=account_id,
params={"account_id": member_account_id},
json={"role_keys": role_keys},
json={"role_ids": role_ids},
)
return MemberRolesResponse.model_validate(data or {})
class MyPermissions:
@staticmethod
def get(
tenant_id: str,
account_id: str | None,
*,
app_id: str | None = None,
dataset_id: str | None = None,
) -> MyPermissionsResponse:
data = _inner_call(
"GET",
f"{_INNER_PREFIX}/my-permissions",
tenant_id=tenant_id,
account_id=account_id,
params={
k: v
for k, v in {
"app_id": app_id,
"dataset_id": dataset_id,
}.items()
if v is not None
}
or None,
)
return MyPermissionsResponse.model_validate(data or {})

View File

@ -19,6 +19,7 @@ changes.
from __future__ import annotations
from types import SimpleNamespace
import inspect
from unittest.mock import patch
import pytest
@ -78,7 +79,7 @@ class TestPydanticModels:
missing required fields) trivial `str` fields are not worth asserting.
"""
def test_role_upsert_requires_name_and_key(self):
def test_role_upsert_requires_name(self):
with pytest.raises(ValidationError):
rbac_mod._RoleUpsertRequest.model_validate({})
@ -86,13 +87,11 @@ class TestPydanticModels:
payload = rbac_mod._RoleUpsertRequest.model_validate(
{
"name": "Owner",
"role_key": "workspace.owner",
"description": "full access",
"permission_keys": ["workspace.member.manage"],
}
)
mutation = payload.to_mutation()
assert mutation.role_key == "workspace.owner"
assert mutation.description == "full access"
assert mutation.permission_keys == ["workspace.member.manage"]
@ -113,12 +112,98 @@ class TestPydanticModels:
def test_replace_role_bindings_defaults_empty(self):
parsed = rbac_mod._ReplaceRoleBindingsRequest.model_validate({})
assert parsed.role_keys == []
assert parsed.role_ids == []
def test_pagination_query_accepts_page_and_limit_aliases(self):
parsed = rbac_mod._PaginationQuery.model_validate({"page": 3, "limit": 25, "reverse": True})
assert parsed.page_number == 3
assert parsed.results_per_page == 25
assert parsed.reverse is True
def test_pagination_query_accepts_legacy_inner_names(self):
parsed = rbac_mod._PaginationQuery.model_validate(
{"page_number": 4, "results_per_page": 30, "reverse": False}
)
assert parsed.page_number == 4
assert parsed.results_per_page == 30
assert parsed.reverse is False
class TestPaginationMapping:
def test_roles_get_forwards_outer_pagination_params(self, app):
with (
app.test_request_context("/workspaces/current/rbac/roles?page=2&limit=50&reverse=true"),
_enabled(True),
patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
patch("controllers.console.workspace.rbac.svc.RBACService.Roles.list") as mock_list,
patch("controllers.console.workspace.rbac._dump", return_value={}),
):
inspect.unwrap(rbac_mod.RBACRolesApi.get)(rbac_mod.RBACRolesApi())
_, kwargs = mock_list.call_args
options = kwargs["options"]
assert options.page_number == 2
assert options.results_per_page == 50
assert options.reverse is True
def test_access_policies_get_forwards_outer_pagination_params(self, app):
with (
app.test_request_context(
"/workspaces/current/rbac/access-policies?resource_type=app&page=3&limit=25&reverse=false"
),
_enabled(True),
patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
patch("controllers.console.workspace.rbac.svc.RBACService.AccessPolicies.list") as mock_list,
patch("controllers.console.workspace.rbac._dump", return_value={}),
):
inspect.unwrap(rbac_mod.RBACAccessPoliciesApi.get)(rbac_mod.RBACAccessPoliciesApi())
_, kwargs = mock_list.call_args
assert kwargs["resource_type"] == "app"
options = kwargs["options"]
assert options.page_number == 3
assert options.results_per_page == 25
assert options.reverse is False
def test_workspace_app_matrix_forwards_outer_pagination_params(self, app):
with (
app.test_request_context("/workspaces/current/rbac/workspace/apps/access-policy?page=4&limit=10"),
_enabled(True),
patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
patch("controllers.console.workspace.rbac.svc.RBACService.WorkspaceAccess.app_matrix") as mock_list,
patch("controllers.console.workspace.rbac._dump", return_value={}),
):
inspect.unwrap(rbac_mod.RBACWorkspaceAppMatrixApi.get)(rbac_mod.RBACWorkspaceAppMatrixApi())
_, kwargs = mock_list.call_args
options = kwargs["options"]
assert options.page_number == 4
assert options.results_per_page == 10
assert options.reverse is None
def test_workspace_dataset_matrix_forwards_outer_pagination_params(self, app):
with (
app.test_request_context(
"/workspaces/current/rbac/workspace/datasets/access-policy?page=5&limit=15&reverse=true"
),
_enabled(True),
patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")),
patch("controllers.console.workspace.rbac.svc.RBACService.WorkspaceAccess.dataset_matrix")
as mock_list,
patch("controllers.console.workspace.rbac._dump", return_value={}),
):
inspect.unwrap(rbac_mod.RBACWorkspaceDatasetMatrixApi.get)(rbac_mod.RBACWorkspaceDatasetMatrixApi())
_, kwargs = mock_list.call_args
options = kwargs["options"]
assert options.page_number == 5
assert options.results_per_page == 15
assert options.reverse is True
class TestDumpHelper:
def test_dump_returns_plain_dict(self):
role = rbac_mod.svc.RBACRole(id="role-1", type="workspace", role_key="workspace.owner", name="Owner")
role = rbac_mod.svc.RBACRole(id="role-1", type="workspace", name="Owner")
dumped = rbac_mod._dump(role)
assert isinstance(dumped, dict)
assert dumped["role_key"] == "workspace.owner"
assert "role_id" not in dumped

View File

@ -69,7 +69,6 @@ class TestRoles:
"tenant_id": "tenant-1",
"type": "workspace",
"category": "global_custom",
"role_key": "workspace.owner",
"name": "Owner",
"permission_keys": ["workspace.member.manage"],
}
@ -88,7 +87,6 @@ class TestRoles:
assert call.endpoint == "/rbac/roles"
assert call.params == {"page_number": 2, "results_per_page": 50, "reverse": "true"}
assert out.pagination and out.pagination.total_count == 1
assert out.data[0].role_key == "workspace.owner"
def test_list_omits_params_when_default(self, mock_send: MagicMock):
mock_send.return_value = {"data": [], "pagination": None}
@ -96,12 +94,7 @@ class TestRoles:
assert _call_args(mock_send).params is None
def test_get_passes_id_query_param(self, mock_send: MagicMock):
mock_send.return_value = {
"id": "role-1",
"type": "workspace",
"role_key": "workspace.owner",
"name": "Owner",
}
mock_send.return_value = {"id": "role-1", "type": "workspace", "name": "Owner"}
svc.RBACService.Roles.get("tenant-1", "acct-1", "role-1")
call = _call_args(mock_send)
assert call.method == "GET"
@ -109,18 +102,8 @@ class TestRoles:
assert call.params == {"id": "role-1"}
def test_create_sends_body(self, mock_send: MagicMock):
mock_send.return_value = {
"id": "role-1",
"type": "workspace",
"role_key": "workspace.owner",
"name": "Owner",
}
payload = svc.RoleMutation(
name="Owner",
role_key="workspace.owner",
description="full access",
permission_keys=["workspace.member.manage"],
)
mock_send.return_value = {"id": "role-1", "type": "workspace", "name": "Owner"}
payload = svc.RoleMutation(name="Owner", description="full access", permission_keys=["workspace.member.manage"])
svc.RBACService.Roles.create("tenant-1", "acct-1", payload)
call = _call_args(mock_send)
@ -128,27 +111,21 @@ class TestRoles:
assert call.endpoint == "/rbac/roles"
assert call.json == {
"name": "Owner",
"role_key": "workspace.owner",
"description": "full access",
"permission_keys": ["workspace.member.manage"],
"type": "workspace",
}
def test_update_sends_id_param_and_body(self, mock_send: MagicMock):
mock_send.return_value = {
"id": "role-1",
"type": "workspace",
"role_key": "workspace.owner",
"name": "Owner",
}
payload = svc.RoleMutation(name="Owner", role_key="workspace.owner", permission_keys=["x"])
mock_send.return_value = {"id": "role-1", "type": "workspace", "name": "Owner"}
payload = svc.RoleMutation(name="Owner", permission_keys=["x"])
svc.RBACService.Roles.update("tenant-1", "acct-1", "role-1", payload)
call = _call_args(mock_send)
assert call.method == "PUT"
assert call.endpoint == "/rbac/roles/item"
assert call.params == {"id": "role-1"}
assert call.json["role_key"] == "workspace.owner"
assert call.json == {"name": "Owner", "description": "", "permission_keys": ["x"], "type": "workspace"}
def test_delete_uses_delete_method(self, mock_send: MagicMock):
mock_send.return_value = {"message": "success"}
@ -216,13 +193,13 @@ class TestResourceAccess:
def test_app_replace_role_bindings(self, mock_send: MagicMock):
mock_send.return_value = {"data": []}
payload = svc.ReplaceRoleBindings(role_keys=["workspace.owner"])
payload = svc.ReplaceRoleBindings(role_ids=["workspace.owner"])
svc.RBACService.AppAccess.replace_role_bindings("tenant-1", "acct-1", "app-1", "policy-1", payload)
call = _call_args(mock_send)
assert call.method == "PUT"
assert call.endpoint == "/rbac/apps/access-policy/role-bindings"
assert call.params == {"app_id": "app-1", "policy_id": "policy-1"}
assert call.json == {"role_keys": ["workspace.owner"]}
assert call.json == {"role_ids": ["workspace.owner"]}
def test_dataset_replace_member_bindings(self, mock_send: MagicMock):
mock_send.return_value = {"data": []}
@ -239,12 +216,16 @@ class TestResourceAccess:
class TestWorkspaceAccess:
def test_app_matrix(self, mock_send: MagicMock):
mock_send.return_value = {"items": []}
svc.RBACService.WorkspaceAccess.app_matrix("tenant-1")
mock_send.return_value = {"items": [], "pagination": {"total_count": 1, "per_page": 20, "current_page": 2, "total_pages": 1}}
out = svc.RBACService.WorkspaceAccess.app_matrix(
"tenant-1",
options=svc.ListOption(page_number=2, results_per_page=20),
)
call = _call_args(mock_send)
assert call.method == "GET"
assert call.endpoint == "/rbac/workspace/apps/access-policy"
assert call.params is None
assert call.params == {"page_number": 2, "results_per_page": 20}
assert out.pagination and out.pagination.current_page == 2
def test_dataset_matrix(self, mock_send: MagicMock):
mock_send.return_value = {"items": []}
@ -256,7 +237,7 @@ class TestWorkspaceAccess:
def test_dataset_replace_role_bindings(self, mock_send: MagicMock):
mock_send.return_value = {"data": []}
payload = svc.ReplaceRoleBindings(role_keys=["workspace.editor"])
payload = svc.ReplaceRoleBindings(role_ids=["workspace.editor"])
svc.RBACService.WorkspaceAccess.replace_dataset_role_bindings(
"tenant-1", "acct-1", "policy-1", payload
)
@ -264,7 +245,40 @@ class TestWorkspaceAccess:
assert call.method == "PUT"
assert call.endpoint == "/rbac/workspace/datasets/access-policy/role-bindings"
assert call.params == {"policy_id": "policy-1"}
assert call.json == {"role_keys": ["workspace.editor"]}
assert call.json == {"role_ids": ["workspace.editor"]}
class TestMyPermissions:
def test_get_without_payload_uses_get(self, mock_send: MagicMock):
mock_send.return_value = {
"workspace": {"permission_keys": ["workspace.member.manage"]},
"app": {"default_permission_keys": ["app.acl.test_and_run"], "overrides": []},
"dataset": {"default_permission_keys": [], "overrides": []},
}
out = svc.RBACService.MyPermissions.get("tenant-1", "acct-1")
call = _call_args(mock_send)
assert call.method == "GET"
assert call.endpoint == "/rbac/my-permissions"
assert call.json is None
assert call.params is None
assert out.workspace.permission_keys == ["workspace.member.manage"]
def test_get_with_single_resource_filters(self, mock_send: MagicMock):
mock_send.return_value = {
"workspace": {"permission_keys": []},
"app": {"default_permission_keys": [], "overrides": [{"resource_id": "app-1", "permission_keys": ["app.acl.edit"]}]},
"dataset": {"default_permission_keys": [], "overrides": []},
}
out = svc.RBACService.MyPermissions.get("tenant-1", "acct-1", app_id="app-1")
call = _call_args(mock_send)
assert call.method == "GET"
assert call.endpoint == "/rbac/my-permissions"
assert call.params == {"app_id": "app-1"}
assert out.app.overrides[0].resource_id == "app-1"
class TestMemberRoles:
@ -275,7 +289,6 @@ class TestMemberRoles:
{
"id": "role-1",
"type": "workspace",
"role_key": "workspace.member",
"name": "Member",
}
],
@ -286,18 +299,18 @@ class TestMemberRoles:
assert call.endpoint == "/rbac/members/rbac-roles"
assert call.params == {"account_id": "acct-2"}
assert out.account_id == "acct-2"
assert out.roles[0].role_key == "workspace.member"
assert out.roles[0].name == "Member"
def test_replace(self, mock_send: MagicMock):
mock_send.return_value = {"account_id": "acct-2", "roles": []}
svc.RBACService.MemberRoles.replace(
"tenant-1", "acct-1", "acct-2", role_keys=["workspace.owner", "workspace.editor"]
"tenant-1", "acct-1", "acct-2", role_ids=["workspace.owner", "workspace.editor"]
)
call = _call_args(mock_send)
assert call.method == "PUT"
assert call.endpoint == "/rbac/members/rbac-roles"
assert call.params == {"account_id": "acct-2"}
assert call.json == {"role_keys": ["workspace.owner", "workspace.editor"]}
assert call.json == {"role_ids": ["workspace.owner", "workspace.editor"]}
class TestListOption: