From b32ec8741eb8bf8916bce2bc7ad05a666051ae4c Mon Sep 17 00:00:00 2001 From: fatelei Date: Thu, 23 Apr 2026 10:21:36 +0800 Subject: [PATCH] feat: rbac backend api --- api/controllers/console/__init__.py | 2 + api/controllers/console/workspace/rbac.py | 643 +++++++++++++++ api/services/enterprise/base.py | 51 ++ api/services/enterprise/rbac_service.py | 771 ++++++++++++++++++ .../console/workspace/test_rbac.py | 124 +++ .../services/enterprise/test_rbac_service.py | 306 +++++++ 6 files changed, 1897 insertions(+) create mode 100644 api/controllers/console/workspace/rbac.py create mode 100644 api/services/enterprise/rbac_service.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_rbac.py create mode 100644 api/tests/unit_tests/services/enterprise/test_rbac_service.py diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 980e828945..d1b9674ec4 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -132,6 +132,7 @@ from .workspace import ( model_providers, models, plugin, + rbac, tool_providers, trigger_providers, workspace, @@ -199,6 +200,7 @@ __all__ = [ "rag_pipeline_draft_variable", "rag_pipeline_import", "rag_pipeline_workflow", + "rbac", "recommended_app", "saved_message", "setup", diff --git a/api/controllers/console/workspace/rbac.py b/api/controllers/console/workspace/rbac.py new file mode 100644 index 0000000000..172154488d --- /dev/null +++ b/api/controllers/console/workspace/rbac.py @@ -0,0 +1,643 @@ +"""Dify Console controllers that proxy the enterprise RBAC surface. + +Each route here is a thin adapter: it validates the pydantic payload shown in +the screenshots (`Settings > Permissions`, `Settings > Access Rules`, +`App/Knowledge Base Access Config`, and the `Settings > Members` role +assignment dialog), pulls ``tenant_id`` / ``account_id`` from the current +Dify session and forwards to the inner RBAC client defined in +``services/enterprise/rbac_service.py``. The client then calls the +``/inner/api/rbac/*`` endpoints on dify-enterprise over HTTP using the +shared ``Enterprise-Api-Secret-Key`` header. +""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import Any + +from flask_restx import Resource +from pydantic import BaseModel, ValidationError +from werkzeug.exceptions import Forbidden, NotFound + +from configs import dify_config +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, setup_required +from libs.login import current_account_with_tenant, login_required +from services.enterprise import rbac_service as svc + + +# --------------------------------------------------------------------------- +# Shared helpers. +# --------------------------------------------------------------------------- + + +def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]: + """Reject every call when the Dify install is not running in enterprise + mode. The dashboard UI shown in the screenshots is an enterprise-only + feature, so every route here should fail fast (and clearly) in community. + """ + + @wraps(view) + def decorated(*args: P.args, **kwargs: P.kwargs) -> R: + if not dify_config.ENTERPRISE_ENABLED: + raise Forbidden("Enterprise edition is not enabled") + return view(*args, **kwargs) + + return decorated + + +def _current_ids() -> tuple[str, str]: + """Return ``(tenant_id, account_id)`` for the authenticated user, or + raise a 404 when no tenant is associated with the session. + """ + + user, tenant_id = current_account_with_tenant() + if not tenant_id: + raise NotFound("Current workspace not found") + return tenant_id, user.id + + +def _payload(model: type[BaseModel]) -> Any: + """Validate the JSON body against ``model`` or raise ``ValidationError``. + + ``ValidationError`` bubbles up as HTTP 400 thanks to + ``controllers/common/helpers.py`` error handling. + """ + try: + return model.model_validate(console_ns.payload or {}) + except ValidationError as exc: + # Re-raise as-is so the upstream error handler renders a 400. + raise exc + + +def _dump(model: BaseModel) -> dict[str, Any]: + return model.model_dump(mode="json") + + +# --------------------------------------------------------------------------- +# Permission catalogs. +# --------------------------------------------------------------------------- + + +@console_ns.route("/workspaces/current/rbac/role-permissions/catalog") +class RBACWorkspaceCatalogApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.Catalog.workspace(tenant_id, account_id)) + + +@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/app") +class RBACAppCatalogApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.Catalog.app(tenant_id, account_id)) + + +@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/dataset") +class RBACDatasetCatalogApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.Catalog.dataset(tenant_id, account_id)) + + +# --------------------------------------------------------------------------- +# Roles. +# --------------------------------------------------------------------------- + + +class _RoleUpsertRequest(BaseModel): + """Accepts the payload sent by the Create/Edit Role dialog.""" + + name: str + role_key: str + description: str = "" + permission_keys: list[str] = [] + + def to_mutation(self) -> svc.RoleMutation: + return svc.RoleMutation( + name=self.name, + role_key=self.role_key, + description=self.description, + permission_keys=list(self.permission_keys), + ) + + +@console_ns.route("/workspaces/current/rbac/roles") +class RBACRolesApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id, account_id = _current_ids() + options = svc.ListOption() + return _dump(svc.RBACService.Roles.list(tenant_id, account_id, options=options)) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def post(self): + tenant_id, account_id = _current_ids() + request = _payload(_RoleUpsertRequest) + role = svc.RBACService.Roles.create(tenant_id, account_id, request.to_mutation()) + return _dump(role), 201 + + +@console_ns.route("/workspaces/current/rbac/roles/") +class RBACRoleItemApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, role_id): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.Roles.get(tenant_id, account_id, str(role_id))) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, role_id): + tenant_id, account_id = _current_ids() + request = _payload(_RoleUpsertRequest) + role = svc.RBACService.Roles.update(tenant_id, account_id, str(role_id), request.to_mutation()) + return _dump(role) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def delete(self, role_id): + tenant_id, account_id = _current_ids() + svc.RBACService.Roles.delete(tenant_id, account_id, str(role_id)) + return {"result": "success"} + + +# --------------------------------------------------------------------------- +# Access policies (tenant-level permission sets). +# --------------------------------------------------------------------------- + + +class _AccessPolicyCreateRequest(BaseModel): + name: str + resource_type: svc.RBACResourceType + description: str = "" + permission_keys: list[str] = [] + + +class _AccessPolicyUpdateRequest(BaseModel): + name: str + description: str = "" + permission_keys: list[str] = [] + + +@console_ns.route("/workspaces/current/rbac/access-policies") +class RBACAccessPoliciesApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id, account_id = _current_ids() + # `resource_type` is exposed as a query argument so the UI can show + # only app-scoped or only dataset-scoped permission sets. + from flask import request + + resource_type = request.args.get("resource_type") or None + return _dump( + svc.RBACService.AccessPolicies.list( + tenant_id, + account_id, + resource_type=resource_type, + options=svc.ListOption(), + ) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def post(self): + tenant_id, account_id = _current_ids() + request = _payload(_AccessPolicyCreateRequest) + policy = svc.RBACService.AccessPolicies.create( + tenant_id, + account_id, + svc.AccessPolicyCreate( + name=request.name, + resource_type=request.resource_type, + description=request.description, + permission_keys=list(request.permission_keys), + ), + ) + return _dump(policy), 201 + + +@console_ns.route("/workspaces/current/rbac/access-policies/") +class RBACAccessPolicyItemApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, policy_id): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.AccessPolicies.get(tenant_id, account_id, str(policy_id))) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_AccessPolicyUpdateRequest) + policy = svc.RBACService.AccessPolicies.update( + tenant_id, + account_id, + str(policy_id), + svc.AccessPolicyUpdate( + name=request.name, + description=request.description, + permission_keys=list(request.permission_keys), + ), + ) + return _dump(policy) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def delete(self, policy_id): + tenant_id, account_id = _current_ids() + svc.RBACService.AccessPolicies.delete(tenant_id, account_id, str(policy_id)) + return {"result": "success"} + + +@console_ns.route("/workspaces/current/rbac/access-policies//copy") +class RBACAccessPolicyCopyApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def post(self, policy_id): + tenant_id, account_id = _current_ids() + policy = svc.RBACService.AccessPolicies.copy(tenant_id, account_id, str(policy_id)) + return _dump(policy), 201 + + +# --------------------------------------------------------------------------- +# Per-app access (App Access Config). +# --------------------------------------------------------------------------- + + +class _ReplaceRoleBindingsRequest(BaseModel): + role_keys: list[str] = [] + + +class _ReplaceMemberBindingsRequest(BaseModel): + account_ids: list[str] = [] + + +@console_ns.route("/workspaces/current/rbac/apps//access-policy") +class RBACAppMatrixApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.AppAccess.matrix(tenant_id, account_id, str(app_id))) + + +@console_ns.route("/workspaces/current/rbac/apps//access-policies//role-bindings") +class RBACAppRoleBindingsApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, app_id, policy_id): + tenant_id, account_id = _current_ids() + return _dump( + svc.RBACService.AppAccess.list_role_bindings(tenant_id, account_id, str(app_id), str(policy_id)) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, app_id, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceRoleBindingsRequest) + return _dump( + svc.RBACService.AppAccess.replace_role_bindings( + tenant_id, + account_id, + str(app_id), + str(policy_id), + svc.ReplaceRoleBindings(role_keys=list(request.role_keys)), + ) + ) + + +@console_ns.route("/workspaces/current/rbac/apps//access-policies//member-bindings") +class RBACAppMemberBindingsApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, app_id, policy_id): + tenant_id, account_id = _current_ids() + return _dump( + svc.RBACService.AppAccess.list_member_bindings(tenant_id, account_id, str(app_id), str(policy_id)) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, app_id, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceMemberBindingsRequest) + return _dump( + svc.RBACService.AppAccess.replace_member_bindings( + tenant_id, + account_id, + str(app_id), + str(policy_id), + svc.ReplaceMemberBindings(account_ids=list(request.account_ids)), + ) + ) + + +# --------------------------------------------------------------------------- +# Per-dataset access (Knowledge Base Access Config). +# --------------------------------------------------------------------------- + + +@console_ns.route("/workspaces/current/rbac/datasets//access-policy") +class RBACDatasetMatrixApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.DatasetAccess.matrix(tenant_id, account_id, str(dataset_id))) + + +@console_ns.route("/workspaces/current/rbac/datasets//access-policies//role-bindings") +class RBACDatasetRoleBindingsApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, policy_id): + tenant_id, account_id = _current_ids() + return _dump( + svc.RBACService.DatasetAccess.list_role_bindings( + tenant_id, account_id, str(dataset_id), str(policy_id) + ) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, dataset_id, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceRoleBindingsRequest) + return _dump( + svc.RBACService.DatasetAccess.replace_role_bindings( + tenant_id, + account_id, + str(dataset_id), + str(policy_id), + svc.ReplaceRoleBindings(role_keys=list(request.role_keys)), + ) + ) + + +@console_ns.route( + "/workspaces/current/rbac/datasets//access-policies//member-bindings" +) +class RBACDatasetMemberBindingsApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, policy_id): + tenant_id, account_id = _current_ids() + return _dump( + svc.RBACService.DatasetAccess.list_member_bindings( + tenant_id, account_id, str(dataset_id), str(policy_id) + ) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, dataset_id, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceMemberBindingsRequest) + return _dump( + svc.RBACService.DatasetAccess.replace_member_bindings( + tenant_id, + account_id, + str(dataset_id), + str(policy_id), + svc.ReplaceMemberBindings(account_ids=list(request.account_ids)), + ) + ) + + +# --------------------------------------------------------------------------- +# Workspace-level access (Settings > Access Rules). +# --------------------------------------------------------------------------- + + +@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy") +class RBACWorkspaceAppMatrixApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.WorkspaceAccess.app_matrix(tenant_id, account_id)) + + +@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies//role-bindings") +class RBACWorkspaceAppRoleBindingsApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, policy_id): + tenant_id, account_id = _current_ids() + return _dump( + svc.RBACService.WorkspaceAccess.list_app_role_bindings(tenant_id, account_id, str(policy_id)) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceRoleBindingsRequest) + return _dump( + svc.RBACService.WorkspaceAccess.replace_app_role_bindings( + tenant_id, + account_id, + str(policy_id), + svc.ReplaceRoleBindings(role_keys=list(request.role_keys)), + ) + ) + + +@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies//member-bindings") +class RBACWorkspaceAppMemberBindingsApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, policy_id): + tenant_id, account_id = _current_ids() + return _dump( + svc.RBACService.WorkspaceAccess.list_app_member_bindings(tenant_id, account_id, str(policy_id)) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceMemberBindingsRequest) + return _dump( + svc.RBACService.WorkspaceAccess.replace_app_member_bindings( + tenant_id, + account_id, + str(policy_id), + svc.ReplaceMemberBindings(account_ids=list(request.account_ids)), + ) + ) + + +@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policy") +class RBACWorkspaceDatasetMatrixApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.WorkspaceAccess.dataset_matrix(tenant_id, account_id)) + + +@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies//role-bindings") +class RBACWorkspaceDatasetRoleBindingsApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, policy_id): + tenant_id, account_id = _current_ids() + return _dump( + svc.RBACService.WorkspaceAccess.list_dataset_role_bindings(tenant_id, account_id, str(policy_id)) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceRoleBindingsRequest) + return _dump( + svc.RBACService.WorkspaceAccess.replace_dataset_role_bindings( + tenant_id, + account_id, + str(policy_id), + svc.ReplaceRoleBindings(role_keys=list(request.role_keys)), + ) + ) + + +@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies//member-bindings") +class RBACWorkspaceDatasetMemberBindingsApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, policy_id): + tenant_id, account_id = _current_ids() + return _dump( + svc.RBACService.WorkspaceAccess.list_dataset_member_bindings(tenant_id, account_id, str(policy_id)) + ) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, policy_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceMemberBindingsRequest) + return _dump( + svc.RBACService.WorkspaceAccess.replace_dataset_member_bindings( + tenant_id, + account_id, + str(policy_id), + svc.ReplaceMemberBindings(account_ids=list(request.account_ids)), + ) + ) + + +# --------------------------------------------------------------------------- +# Member ↔ role bindings (Settings > Members > Assign roles). +# --------------------------------------------------------------------------- + + +class _ReplaceMemberRolesRequest(BaseModel): + role_keys: list[str] = [] + + +@console_ns.route("/workspaces/current/rbac/members//rbac-roles") +class RBACMemberRolesApi(Resource): + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def get(self, member_id): + tenant_id, account_id = _current_ids() + return _dump(svc.RBACService.MemberRoles.get(tenant_id, account_id, str(member_id))) + + @enterprise_only + @setup_required + @login_required + @account_initialization_required + def put(self, member_id): + tenant_id, account_id = _current_ids() + request = _payload(_ReplaceMemberRolesRequest) + return _dump( + svc.RBACService.MemberRoles.replace( + tenant_id, + account_id, + str(member_id), + role_keys=list(request.role_keys), + ) + ) diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 68835e76d0..1d22da00f4 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -5,6 +5,7 @@ from typing import Any import httpx +from configs import dify_config from core.helper.trace_id_helper import generate_traceparent_header from services.errors.enterprise import ( EnterpriseAPIBadRequestError, @@ -16,6 +17,11 @@ from services.errors.enterprise import ( logger = logging.getLogger(__name__) +# Headers recognised by dify-enterprise's /inner/api/rbac/* endpoints. +# Keep in sync with pkg/enterprise/service/rbac_inner_handlers.go. +INNER_TENANT_ID_HEADER = "X-Inner-Tenant-Id" +INNER_ACCOUNT_ID_HEADER = "X-Inner-Account-Id" + class BaseRequest: proxies: Mapping[str, str] | None = { @@ -49,8 +55,16 @@ class BaseRequest: *, timeout: float | httpx.Timeout | None = None, raise_for_status: bool = False, + extra_headers: Mapping[str, str] | None = None, ) -> Any: headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key} + if extra_headers: + # Explicitly ignore empty values so callers can pass optional + # headers (e.g. `X-Inner-Account-Id`) without having to branch. + for key, value in extra_headers.items(): + if value is None or value == "": + continue + headers[key] = value url = f"{cls.base_url}{endpoint}" mounts = cls._build_mounts() @@ -122,6 +136,43 @@ class EnterpriseRequest(BaseRequest): secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") secret_key_header = "Enterprise-Api-Secret-Key" + @classmethod + def send_inner_rbac_request( + cls, + method: str, + endpoint: str, + *, + tenant_id: str, + account_id: str | None = None, + json: Any | None = None, + params: Mapping[str, Any] | None = None, + timeout: float | httpx.Timeout | None = None, + ) -> Any: + """Call an /inner/api/rbac/* endpoint on dify-enterprise. + + Inner RBAC endpoints require three headers on top of the standard + Enterprise-Api-Secret-Key: the tenant the call targets and (optionally) + the account acting on behalf of the workspace. This helper centralises + both the assertions and the header wiring so callers only have to + supply business payload. + """ + if not dify_config.ENTERPRISE_ENABLED: + raise EnterpriseAPIError("Enterprise edition is not enabled") + if not tenant_id: + raise ValueError("tenant_id must be provided for inner RBAC requests") + + inner_headers: dict[str, str] = {INNER_TENANT_ID_HEADER: tenant_id} + if account_id: + inner_headers[INNER_ACCOUNT_ID_HEADER] = account_id + return cls.send_request( + method, + endpoint, + json=json, + params=params, + timeout=timeout, + extra_headers=inner_headers, + ) + class EnterprisePluginManagerRequest(BaseRequest): base_url = os.environ.get("ENTERPRISE_PLUGIN_MANAGER_API_URL", "ENTERPRISE_PLUGIN_MANAGER_API_URL") diff --git a/api/services/enterprise/rbac_service.py b/api/services/enterprise/rbac_service.py new file mode 100644 index 0000000000..8b1b3fc25d --- /dev/null +++ b/api/services/enterprise/rbac_service.py @@ -0,0 +1,771 @@ +from __future__ import annotations + +from enum import StrEnum +from typing import Any, Generic, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + +from services.enterprise.base import EnterpriseRequest + +T = TypeVar("T") + + +class _RBACModel(BaseModel): + model_config = ConfigDict(populate_by_name=True, extra="ignore") + + +class Pagination(_RBACModel): + total_count: int = 0 + per_page: int = 0 + current_page: int = 0 + total_pages: int = 0 + + +class Paginated(_RBACModel, Generic[T]): + data: list[T] = Field(default_factory=list) + pagination: Pagination | None = None + + +class RBACResourceType(StrEnum): + """Resource types understood by access policies.""" + + APP = "app" + DATASET = "dataset" + + +class RBACRoleType(StrEnum): + """The only concrete role type after the access-policy refactor.""" + + WORKSPACE = "workspace" + + +class PermissionCatalogItem(_RBACModel): + key: str + name: str + description: str = "" + + +class PermissionCatalogGroup(_RBACModel): + group_key: str + group_name: str + description: str = "" + permissions: list[PermissionCatalogItem] = Field(default_factory=list) + + +class PermissionCatalogResponse(_RBACModel): + groups: list[PermissionCatalogGroup] = Field(default_factory=list) + + +class RBACRole(_RBACModel): + id: str + tenant_id: str | None = None + type: str + category: str = "" + role_key: str + name: str + description: str = "" + is_builtin: bool = False + permission_keys: list[str] = Field(default_factory=list) + + +class AccessPolicy(_RBACModel): + id: str + tenant_id: str = "" + resource_type: str + policy_key: str = "" + name: str + description: str = "" + permission_keys: list[str] = Field(default_factory=list) + is_builtin: bool = False + category: str = "" + created_at: int = 0 + updated_at: int = 0 + + +class AccessPolicyRoleBinding(_RBACModel): + id: str + tenant_id: str = "" + access_policy_id: str + resource_type: str + resource_id: str = "" + role_key: str + created_at: int = 0 + + +class AccessPolicyMemberBinding(_RBACModel): + id: str + tenant_id: str = "" + access_policy_id: str + resource_type: str + resource_id: str = "" + account_id: str + created_at: int = 0 + + +class AccessMatrixItem(_RBACModel): + policy: AccessPolicy | None = None + role_keys: list[str] = Field(default_factory=list) + account_ids: list[str] = Field(default_factory=list) + + +class ResourceAccessMatrix(_RBACModel): + resource_type: str + resource_id: str = "" + items: list[AccessMatrixItem] = Field(default_factory=list) + + +class WorkspaceAccessMatrix(_RBACModel): + resource_type: str + items: list[AccessMatrixItem] = Field(default_factory=list) + + +class RoleBindingsResponse(_RBACModel): + data: list[AccessPolicyRoleBinding] = Field(default_factory=list) + + +class MemberBindingsResponse(_RBACModel): + data: list[AccessPolicyMemberBinding] = Field(default_factory=list) + + +class MemberRolesResponse(_RBACModel): + account_id: str + roles: list[RBACRole] = Field(default_factory=list) + + +# ---------- Mutation request models ---------- + + +class RoleMutation(_RBACModel): + """Payload shared by role create & update. + + ``type`` defaults to ``workspace`` because that is the only concrete role + type supported by the enterprise backend today (see biz.RBACRoleType). + """ + + name: str + role_key: str + description: str = "" + permission_keys: list[str] = Field(default_factory=list) + type: RBACRoleType = RBACRoleType.WORKSPACE + + +class AccessPolicyCreate(_RBACModel): + name: str + resource_type: RBACResourceType + description: str = "" + permission_keys: list[str] = Field(default_factory=list) + + +class AccessPolicyUpdate(_RBACModel): + name: str + description: str = "" + permission_keys: list[str] = Field(default_factory=list) + + +class ReplaceRoleBindings(_RBACModel): + role_keys: list[str] = Field(default_factory=list) + + +class ReplaceMemberBindings(_RBACModel): + account_ids: list[str] = Field(default_factory=list) + + +class ListOption(_RBACModel): + page_number: int | None = None + results_per_page: int | None = None + reverse: bool | None = None + + def to_params(self, extra: dict[str, Any] | None = None) -> dict[str, Any]: + params: dict[str, Any] = {} + if self.page_number is not None: + params["page_number"] = self.page_number + if self.results_per_page is not None: + params["results_per_page"] = self.results_per_page + if self.reverse is not None: + # httpx renders `True` as the string "True"; we want the inner + # handler to match on the lowercase form it compares against. + params["reverse"] = "true" if self.reverse else "false" + if extra: + params.update({k: v for k, v in extra.items() if v is not None}) + return params + + +_INNER_PREFIX = "/rbac" + + +def _inner_call( + method: str, + endpoint: str, + *, + tenant_id: str, + account_id: str | None = None, + json: Any | None = None, + params: dict[str, Any] | None = None, +) -> Any: + """Thin wrapper around `EnterpriseRequest.send_inner_rbac_request`. + + Kept as a module-level helper (rather than a nested-class method) so that + unit tests can monkey-patch this single entry point instead of every + individual `Roles.*`, `AccessPolicies.*`, … method. + """ + return EnterpriseRequest.send_inner_rbac_request( + method, + endpoint, + tenant_id=tenant_id, + account_id=account_id, + json=json, + params=params, + ) + + +class RBACService: + """Single entry point grouping every inner RBAC call by feature area. + + Each nested class keeps the classmethods tightly scoped to one URL family + so call sites read naturally (e.g. ``RBACService.Roles.create(tenant_id, + account_id, payload)``). + """ + + # ------------------------------------------------------------------ + # Permission catalog (screenshot 3: 新增/编辑角色 弹窗内的权限列表). + # ------------------------------------------------------------------ + class Catalog: + @staticmethod + def workspace(tenant_id: str, account_id: str | None = None) -> PermissionCatalogResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/role-permissions/catalog", + tenant_id=tenant_id, + account_id=account_id, + ) + return PermissionCatalogResponse.model_validate(data or {}) + + @staticmethod + def app(tenant_id: str, account_id: str | None = None) -> PermissionCatalogResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/role-permissions/catalog/app", + tenant_id=tenant_id, + account_id=account_id, + ) + return PermissionCatalogResponse.model_validate(data or {}) + + @staticmethod + def dataset(tenant_id: str, account_id: str | None = None) -> PermissionCatalogResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/role-permissions/catalog/dataset", + tenant_id=tenant_id, + account_id=account_id, + ) + return PermissionCatalogResponse.model_validate(data or {}) + + # ------------------------------------------------------------------ + # Role CRUD (Settings > Permissions). + # ------------------------------------------------------------------ + class Roles: + @staticmethod + def list( + tenant_id: str, + account_id: str | None = None, + *, + options: ListOption | None = None, + ) -> Paginated[RBACRole]: + params = (options or ListOption()).to_params() + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/roles", + tenant_id=tenant_id, + account_id=account_id, + params=params or None, + ) + data = data or {} + return Paginated[RBACRole]( + data=[RBACRole.model_validate(item) for item in data.get("data") or []], + pagination=Pagination.model_validate(data["pagination"]) if data.get("pagination") else None, + ) + + @staticmethod + def get(tenant_id: str, account_id: str | None, role_id: str) -> RBACRole: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/roles/item", + tenant_id=tenant_id, + account_id=account_id, + params={"id": role_id}, + ) + return RBACRole.model_validate(data or {}) + + @staticmethod + def create(tenant_id: str, account_id: str | None, payload: RoleMutation) -> RBACRole: + data = _inner_call( + "POST", + f"{_INNER_PREFIX}/roles", + tenant_id=tenant_id, + account_id=account_id, + json=payload.model_dump(mode="json"), + ) + return RBACRole.model_validate(data or {}) + + @staticmethod + def update(tenant_id: str, account_id: str | None, role_id: str, payload: RoleMutation) -> RBACRole: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/roles/item", + tenant_id=tenant_id, + account_id=account_id, + params={"id": role_id}, + json=payload.model_dump(mode="json"), + ) + return RBACRole.model_validate(data or {}) + + @staticmethod + def delete(tenant_id: str, account_id: str | None, role_id: str) -> None: + _inner_call( + "DELETE", + f"{_INNER_PREFIX}/roles/item", + tenant_id=tenant_id, + account_id=account_id, + params={"id": role_id}, + ) + + # ------------------------------------------------------------------ + # Access policies (Settings > Access Rules: create/edit permission sets). + # ------------------------------------------------------------------ + class AccessPolicies: + @staticmethod + def list( + tenant_id: str, + account_id: str | None = None, + *, + resource_type: RBACResourceType | str | None = None, + options: ListOption | None = None, + ) -> Paginated[AccessPolicy]: + extra: dict[str, Any] = {} + if resource_type is not None: + extra["resource_type"] = ( + resource_type.value if isinstance(resource_type, RBACResourceType) else resource_type + ) + params = (options or ListOption()).to_params(extra) + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/access-policies", + tenant_id=tenant_id, + account_id=account_id, + params=params or None, + ) + data = data or {} + return Paginated[AccessPolicy]( + data=[AccessPolicy.model_validate(item) for item in data.get("data") or []], + pagination=Pagination.model_validate(data["pagination"]) if data.get("pagination") else None, + ) + + @staticmethod + def get(tenant_id: str, account_id: str | None, policy_id: str) -> AccessPolicy: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/access-policies/item", + tenant_id=tenant_id, + account_id=account_id, + params={"id": policy_id}, + ) + return AccessPolicy.model_validate(data or {}) + + @staticmethod + def create(tenant_id: str, account_id: str | None, payload: AccessPolicyCreate) -> AccessPolicy: + data = _inner_call( + "POST", + f"{_INNER_PREFIX}/access-policies", + tenant_id=tenant_id, + account_id=account_id, + json=payload.model_dump(mode="json"), + ) + return AccessPolicy.model_validate(data or {}) + + @staticmethod + def update( + tenant_id: str, + account_id: str | None, + policy_id: str, + payload: AccessPolicyUpdate, + ) -> AccessPolicy: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/access-policies/item", + tenant_id=tenant_id, + account_id=account_id, + params={"id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return AccessPolicy.model_validate(data or {}) + + @staticmethod + def copy(tenant_id: str, account_id: str | None, policy_id: str) -> AccessPolicy: + data = _inner_call( + "POST", + f"{_INNER_PREFIX}/access-policies/copy", + tenant_id=tenant_id, + account_id=account_id, + params={"id": policy_id}, + ) + return AccessPolicy.model_validate(data or {}) + + @staticmethod + def delete(tenant_id: str, account_id: str | None, policy_id: str) -> None: + _inner_call( + "DELETE", + f"{_INNER_PREFIX}/access-policies/item", + tenant_id=tenant_id, + account_id=account_id, + params={"id": policy_id}, + ) + + # ------------------------------------------------------------------ + # Per-app access (screenshot 1: App Access Config). + # ------------------------------------------------------------------ + class AppAccess: + @staticmethod + def matrix(tenant_id: str, account_id: str | None, app_id: str) -> ResourceAccessMatrix: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/apps/access-policy", + tenant_id=tenant_id, + account_id=account_id, + params={"app_id": app_id}, + ) + return ResourceAccessMatrix.model_validate(data or {}) + + @staticmethod + def list_role_bindings( + tenant_id: str, + account_id: str | None, + app_id: str, + policy_id: str, + ) -> RoleBindingsResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/apps/access-policy/role-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"app_id": app_id, "policy_id": policy_id}, + ) + return RoleBindingsResponse.model_validate(data or {}) + + @staticmethod + def replace_role_bindings( + tenant_id: str, + account_id: str | None, + app_id: str, + policy_id: str, + payload: ReplaceRoleBindings, + ) -> RoleBindingsResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/apps/access-policy/role-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"app_id": app_id, "policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return RoleBindingsResponse.model_validate(data or {}) + + @staticmethod + def list_member_bindings( + tenant_id: str, + account_id: str | None, + app_id: str, + policy_id: str, + ) -> MemberBindingsResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/apps/access-policy/member-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"app_id": app_id, "policy_id": policy_id}, + ) + return MemberBindingsResponse.model_validate(data or {}) + + @staticmethod + def replace_member_bindings( + tenant_id: str, + account_id: str | None, + app_id: str, + policy_id: str, + payload: ReplaceMemberBindings, + ) -> MemberBindingsResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/apps/access-policy/member-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"app_id": app_id, "policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return MemberBindingsResponse.model_validate(data or {}) + + # ------------------------------------------------------------------ + # Per-dataset access (screenshot 1: Knowledge Base Access Config). + # ------------------------------------------------------------------ + class DatasetAccess: + @staticmethod + def matrix(tenant_id: str, account_id: str | None, dataset_id: str) -> ResourceAccessMatrix: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/datasets/access-policy", + tenant_id=tenant_id, + account_id=account_id, + params={"dataset_id": dataset_id}, + ) + return ResourceAccessMatrix.model_validate(data or {}) + + @staticmethod + def list_role_bindings( + tenant_id: str, + account_id: str | None, + dataset_id: str, + policy_id: str, + ) -> RoleBindingsResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/datasets/access-policy/role-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"dataset_id": dataset_id, "policy_id": policy_id}, + ) + return RoleBindingsResponse.model_validate(data or {}) + + @staticmethod + def replace_role_bindings( + tenant_id: str, + account_id: str | None, + dataset_id: str, + policy_id: str, + payload: ReplaceRoleBindings, + ) -> RoleBindingsResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/datasets/access-policy/role-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"dataset_id": dataset_id, "policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return RoleBindingsResponse.model_validate(data or {}) + + @staticmethod + def list_member_bindings( + tenant_id: str, + account_id: str | None, + dataset_id: str, + policy_id: str, + ) -> MemberBindingsResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/datasets/access-policy/member-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"dataset_id": dataset_id, "policy_id": policy_id}, + ) + return MemberBindingsResponse.model_validate(data or {}) + + @staticmethod + def replace_member_bindings( + tenant_id: str, + account_id: str | None, + dataset_id: str, + policy_id: str, + payload: ReplaceMemberBindings, + ) -> MemberBindingsResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/datasets/access-policy/member-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"dataset_id": dataset_id, "policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return MemberBindingsResponse.model_validate(data or {}) + + # ------------------------------------------------------------------ + # Workspace-level access (screenshot 2: Settings > Access Rules). + # ------------------------------------------------------------------ + class WorkspaceAccess: + @staticmethod + def app_matrix(tenant_id: str, account_id: str | None = None) -> WorkspaceAccessMatrix: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/workspace/apps/access-policy", + tenant_id=tenant_id, + account_id=account_id, + ) + return WorkspaceAccessMatrix.model_validate(data or {}) + + @staticmethod + def dataset_matrix(tenant_id: str, account_id: str | None = None) -> WorkspaceAccessMatrix: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/workspace/datasets/access-policy", + tenant_id=tenant_id, + account_id=account_id, + ) + return WorkspaceAccessMatrix.model_validate(data or {}) + + @staticmethod + def list_app_role_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + ) -> RoleBindingsResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/workspace/apps/access-policy/role-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + ) + return RoleBindingsResponse.model_validate(data or {}) + + @staticmethod + def replace_app_role_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + payload: ReplaceRoleBindings, + ) -> RoleBindingsResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/workspace/apps/access-policy/role-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return RoleBindingsResponse.model_validate(data or {}) + + @staticmethod + def list_app_member_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + ) -> MemberBindingsResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/workspace/apps/access-policy/member-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + ) + return MemberBindingsResponse.model_validate(data or {}) + + @staticmethod + def replace_app_member_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + payload: ReplaceMemberBindings, + ) -> MemberBindingsResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/workspace/apps/access-policy/member-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return MemberBindingsResponse.model_validate(data or {}) + + @staticmethod + def list_dataset_role_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + ) -> RoleBindingsResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/workspace/datasets/access-policy/role-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + ) + return RoleBindingsResponse.model_validate(data or {}) + + @staticmethod + def replace_dataset_role_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + payload: ReplaceRoleBindings, + ) -> RoleBindingsResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/workspace/datasets/access-policy/role-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return RoleBindingsResponse.model_validate(data or {}) + + @staticmethod + def list_dataset_member_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + ) -> MemberBindingsResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/workspace/datasets/access-policy/member-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + ) + return MemberBindingsResponse.model_validate(data or {}) + + @staticmethod + def replace_dataset_member_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + payload: ReplaceMemberBindings, + ) -> MemberBindingsResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/workspace/datasets/access-policy/member-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return MemberBindingsResponse.model_validate(data or {}) + + # ------------------------------------------------------------------ + # Member ↔ role bindings (screenshot 3: Settings > Members > Assign roles). + # ------------------------------------------------------------------ + class MemberRoles: + @staticmethod + def get(tenant_id: str, account_id: str | None, member_account_id: str) -> MemberRolesResponse: + data = _inner_call( + "GET", + f"{_INNER_PREFIX}/members/rbac-roles", + tenant_id=tenant_id, + account_id=account_id, + params={"account_id": member_account_id}, + ) + return MemberRolesResponse.model_validate(data or {}) + + @staticmethod + def replace( + tenant_id: str, + account_id: str | None, + member_account_id: str, + role_keys: list[str], + ) -> MemberRolesResponse: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/members/rbac-roles", + tenant_id=tenant_id, + account_id=account_id, + params={"account_id": member_account_id}, + json={"role_keys": role_keys}, + ) + return MemberRolesResponse.model_validate(data or {}) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_rbac.py b/api/tests/unit_tests/controllers/console/workspace/test_rbac.py new file mode 100644 index 0000000000..664ffe4b0c --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_rbac.py @@ -0,0 +1,124 @@ +"""Controller tests for ``controllers.console.workspace.rbac``. + +The controllers here are thin: almost every non-trivial behaviour lives in +``services.enterprise.rbac_service`` (covered by its own suite). These tests +therefore focus on the three Flask-layer concerns the service layer cannot +exercise: + +* ``enterprise_only`` rejects community-edition calls with 403 (it is the + outermost decorator, so it fires before any auth middleware). +* ``_current_ids`` raises 404 when the session has no tenant. +* The pydantic request models accept / reject bodies as expected. + +We explicitly avoid "happy-path" integration tests through the full +decorator stack — those belong in e2e tests where a real Dify session is +available — to keep this suite fast and resilient to ancillary auth wiring +changes. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +import pytest +from flask import Flask +from pydantic import ValidationError +from werkzeug.exceptions import Forbidden, NotFound + +from controllers.console.workspace import rbac as rbac_mod + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +def _enabled(enabled: bool): + return patch("controllers.console.workspace.rbac.dify_config.ENTERPRISE_ENABLED", enabled) + + +class TestEnterpriseGate: + """``enterprise_only`` is the outermost decorator on every resource, so we + can exercise it directly — no auth stubs required. + """ + + def test_catalog_forbidden_when_disabled(self, app): + with app.test_request_context("/workspaces/current/rbac/role-permissions/catalog"), _enabled(False): + with pytest.raises(Forbidden): + rbac_mod.RBACWorkspaceCatalogApi().get() + + def test_roles_post_forbidden_when_disabled(self, app): + with ( + app.test_request_context("/workspaces/current/rbac/roles", method="POST", json={}), + _enabled(False), + ): + with pytest.raises(Forbidden): + rbac_mod.RBACRolesApi().post() + + +class TestCurrentIds: + def test_rejects_missing_tenant(self): + with patch("controllers.console.workspace.rbac.current_account_with_tenant") as mock_user: + mock_user.return_value = (SimpleNamespace(id="acct-1"), None) + with pytest.raises(NotFound): + rbac_mod._current_ids() + + def test_returns_tuple(self): + with patch("controllers.console.workspace.rbac.current_account_with_tenant") as mock_user: + mock_user.return_value = (SimpleNamespace(id="acct-1"), "tenant-1") + assert rbac_mod._current_ids() == ("tenant-1", "acct-1") + + +class TestPydanticModels: + """The internal `_…Request` models are the contract between the browser + and the controllers. We only check non-obvious branches (enum parsing, + missing required fields) — trivial `str` fields are not worth asserting. + """ + + def test_role_upsert_requires_name_and_key(self): + with pytest.raises(ValidationError): + rbac_mod._RoleUpsertRequest.model_validate({}) + + def test_role_upsert_to_mutation_preserves_fields(self): + payload = rbac_mod._RoleUpsertRequest.model_validate( + { + "name": "Owner", + "role_key": "workspace.owner", + "description": "full access", + "permission_keys": ["workspace.member.manage"], + } + ) + mutation = payload.to_mutation() + assert mutation.role_key == "workspace.owner" + assert mutation.description == "full access" + assert mutation.permission_keys == ["workspace.member.manage"] + + def test_access_policy_create_parses_resource_type_enum(self): + parsed = rbac_mod._AccessPolicyCreateRequest.model_validate( + { + "name": "Full access", + "resource_type": "app", + "description": "", + "permission_keys": [], + } + ) + assert parsed.resource_type is rbac_mod.svc.RBACResourceType.APP + + def test_access_policy_create_rejects_unknown_resource_type(self): + with pytest.raises(ValidationError): + rbac_mod._AccessPolicyCreateRequest.model_validate({"name": "bad", "resource_type": "unknown"}) + + def test_replace_role_bindings_defaults_empty(self): + parsed = rbac_mod._ReplaceRoleBindingsRequest.model_validate({}) + assert parsed.role_keys == [] + + +class TestDumpHelper: + def test_dump_returns_plain_dict(self): + role = rbac_mod.svc.RBACRole(id="role-1", type="workspace", role_key="workspace.owner", name="Owner") + dumped = rbac_mod._dump(role) + assert isinstance(dumped, dict) + assert dumped["role_key"] == "workspace.owner" diff --git a/api/tests/unit_tests/services/enterprise/test_rbac_service.py b/api/tests/unit_tests/services/enterprise/test_rbac_service.py new file mode 100644 index 0000000000..04640c5037 --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_rbac_service.py @@ -0,0 +1,306 @@ +"""Unit tests for services.enterprise.rbac_service. + +The enterprise RBAC client is almost pure glue: each method turns a single +``EnterpriseRequest.send_inner_rbac_request`` call into a pydantic response +model. Rather than spinning up an HTTP server we monkeypatch that helper and +assert on the arguments it received; that catches both routing regressions +(wrong method / wrong path / wrong params) and model-shape regressions in +one place. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from services.enterprise import rbac_service as svc + +MODULE = "services.enterprise.rbac_service" + + +@pytest.fixture +def mock_send(): + with patch(f"{MODULE}.EnterpriseRequest.send_inner_rbac_request") as send: + yield send + + +def _call_args(send: MagicMock) -> SimpleNamespace: + """Return the most recent (method, endpoint, kwargs) sent to the mock.""" + send.assert_called_once() + args, kwargs = send.call_args + return SimpleNamespace(method=args[0], endpoint=args[1], **kwargs) + + +class TestCatalog: + def test_workspace_catalog(self, mock_send: MagicMock): + mock_send.return_value = {"groups": [{"group_key": "workspace", "group_name": "工作空间", "permissions": []}]} + + out = svc.RBACService.Catalog.workspace("tenant-1", account_id="acct-1") + + call = _call_args(mock_send) + assert call.method == "GET" + assert call.endpoint == "/rbac/role-permissions/catalog" + assert call.tenant_id == "tenant-1" + assert call.account_id == "acct-1" + assert call.json is None + assert call.params is None + assert len(out.groups) == 1 + assert out.groups[0].group_key == "workspace" + + def test_app_catalog_endpoint(self, mock_send: MagicMock): + mock_send.return_value = {"groups": []} + svc.RBACService.Catalog.app("tenant-1") + assert mock_send.call_args.args[1] == "/rbac/role-permissions/catalog/app" + + def test_dataset_catalog_endpoint(self, mock_send: MagicMock): + mock_send.return_value = {"groups": []} + svc.RBACService.Catalog.dataset("tenant-1") + assert mock_send.call_args.args[1] == "/rbac/role-permissions/catalog/dataset" + + +class TestRoles: + def test_list_forwards_pagination_options(self, mock_send: MagicMock): + mock_send.return_value = { + "data": [ + { + "id": "role-1", + "tenant_id": "tenant-1", + "type": "workspace", + "category": "global_custom", + "role_key": "workspace.owner", + "name": "Owner", + "permission_keys": ["workspace.member.manage"], + } + ], + "pagination": {"total_count": 1, "per_page": 20, "current_page": 1, "total_pages": 1}, + } + + out = svc.RBACService.Roles.list( + "tenant-1", + "acct-1", + options=svc.ListOption(page_number=2, results_per_page=50, reverse=True), + ) + + call = _call_args(mock_send) + assert call.method == "GET" + assert call.endpoint == "/rbac/roles" + assert call.params == {"page_number": 2, "results_per_page": 50, "reverse": "true"} + assert out.pagination and out.pagination.total_count == 1 + assert out.data[0].role_key == "workspace.owner" + + def test_list_omits_params_when_default(self, mock_send: MagicMock): + mock_send.return_value = {"data": [], "pagination": None} + svc.RBACService.Roles.list("tenant-1") + assert _call_args(mock_send).params is None + + def test_get_passes_id_query_param(self, mock_send: MagicMock): + mock_send.return_value = { + "id": "role-1", + "type": "workspace", + "role_key": "workspace.owner", + "name": "Owner", + } + svc.RBACService.Roles.get("tenant-1", "acct-1", "role-1") + call = _call_args(mock_send) + assert call.method == "GET" + assert call.endpoint == "/rbac/roles/item" + assert call.params == {"id": "role-1"} + + def test_create_sends_body(self, mock_send: MagicMock): + mock_send.return_value = { + "id": "role-1", + "type": "workspace", + "role_key": "workspace.owner", + "name": "Owner", + } + payload = svc.RoleMutation( + name="Owner", + role_key="workspace.owner", + description="full access", + permission_keys=["workspace.member.manage"], + ) + svc.RBACService.Roles.create("tenant-1", "acct-1", payload) + + call = _call_args(mock_send) + assert call.method == "POST" + assert call.endpoint == "/rbac/roles" + assert call.json == { + "name": "Owner", + "role_key": "workspace.owner", + "description": "full access", + "permission_keys": ["workspace.member.manage"], + "type": "workspace", + } + + def test_update_sends_id_param_and_body(self, mock_send: MagicMock): + mock_send.return_value = { + "id": "role-1", + "type": "workspace", + "role_key": "workspace.owner", + "name": "Owner", + } + payload = svc.RoleMutation(name="Owner", role_key="workspace.owner", permission_keys=["x"]) + svc.RBACService.Roles.update("tenant-1", "acct-1", "role-1", payload) + + call = _call_args(mock_send) + assert call.method == "PUT" + assert call.endpoint == "/rbac/roles/item" + assert call.params == {"id": "role-1"} + assert call.json["role_key"] == "workspace.owner" + + def test_delete_uses_delete_method(self, mock_send: MagicMock): + mock_send.return_value = {"message": "success"} + svc.RBACService.Roles.delete("tenant-1", None, "role-1") + + call = _call_args(mock_send) + assert call.method == "DELETE" + assert call.endpoint == "/rbac/roles/item" + assert call.params == {"id": "role-1"} + assert call.account_id is None + + +class TestAccessPolicies: + def test_list_filters_by_resource_type(self, mock_send: MagicMock): + mock_send.return_value = {"data": [], "pagination": None} + svc.RBACService.AccessPolicies.list( + "tenant-1", + "acct-1", + resource_type=svc.RBACResourceType.APP, + options=svc.ListOption(page_number=1), + ) + call = _call_args(mock_send) + assert call.endpoint == "/rbac/access-policies" + assert call.params == {"page_number": 1, "resource_type": "app"} + + def test_copy_sends_post_with_id_param(self, mock_send: MagicMock): + mock_send.return_value = { + "id": "policy-1-copy", + "resource_type": "app", + "name": "Full access copy", + } + svc.RBACService.AccessPolicies.copy("tenant-1", "acct-1", "policy-1") + call = _call_args(mock_send) + assert call.method == "POST" + assert call.endpoint == "/rbac/access-policies/copy" + assert call.params == {"id": "policy-1"} + + def test_create_serialises_resource_type_enum(self, mock_send: MagicMock): + mock_send.return_value = {"id": "policy-1", "resource_type": "dataset", "name": "KB only"} + payload = svc.AccessPolicyCreate( + name="KB only", + resource_type=svc.RBACResourceType.DATASET, + permission_keys=["dataset.acl.readonly"], + ) + svc.RBACService.AccessPolicies.create("tenant-1", "acct-1", payload) + call = _call_args(mock_send) + assert call.method == "POST" + assert call.json == { + "name": "KB only", + "resource_type": "dataset", + "description": "", + "permission_keys": ["dataset.acl.readonly"], + } + + +class TestResourceAccess: + def test_app_matrix(self, mock_send: MagicMock): + mock_send.return_value = {"resource_type": "app", "resource_id": "app-1", "items": []} + svc.RBACService.AppAccess.matrix("tenant-1", "acct-1", "app-1") + call = _call_args(mock_send) + assert call.method == "GET" + assert call.endpoint == "/rbac/apps/access-policy" + assert call.params == {"app_id": "app-1"} + + def test_app_replace_role_bindings(self, mock_send: MagicMock): + mock_send.return_value = {"data": []} + payload = svc.ReplaceRoleBindings(role_keys=["workspace.owner"]) + svc.RBACService.AppAccess.replace_role_bindings("tenant-1", "acct-1", "app-1", "policy-1", payload) + call = _call_args(mock_send) + assert call.method == "PUT" + assert call.endpoint == "/rbac/apps/access-policy/role-bindings" + assert call.params == {"app_id": "app-1", "policy_id": "policy-1"} + assert call.json == {"role_keys": ["workspace.owner"]} + + def test_dataset_replace_member_bindings(self, mock_send: MagicMock): + mock_send.return_value = {"data": []} + payload = svc.ReplaceMemberBindings(account_ids=["acct-2"]) + svc.RBACService.DatasetAccess.replace_member_bindings( + "tenant-1", "acct-1", "ds-1", "policy-1", payload + ) + call = _call_args(mock_send) + assert call.method == "PUT" + assert call.endpoint == "/rbac/datasets/access-policy/member-bindings" + assert call.params == {"dataset_id": "ds-1", "policy_id": "policy-1"} + assert call.json == {"account_ids": ["acct-2"]} + + +class TestWorkspaceAccess: + def test_app_matrix(self, mock_send: MagicMock): + mock_send.return_value = {"resource_type": "app", "items": []} + svc.RBACService.WorkspaceAccess.app_matrix("tenant-1") + call = _call_args(mock_send) + assert call.method == "GET" + assert call.endpoint == "/rbac/workspace/apps/access-policy" + assert call.params is None + + def test_dataset_replace_role_bindings(self, mock_send: MagicMock): + mock_send.return_value = {"data": []} + payload = svc.ReplaceRoleBindings(role_keys=["workspace.editor"]) + svc.RBACService.WorkspaceAccess.replace_dataset_role_bindings( + "tenant-1", "acct-1", "policy-1", payload + ) + call = _call_args(mock_send) + assert call.method == "PUT" + assert call.endpoint == "/rbac/workspace/datasets/access-policy/role-bindings" + assert call.params == {"policy_id": "policy-1"} + assert call.json == {"role_keys": ["workspace.editor"]} + + +class TestMemberRoles: + def test_get(self, mock_send: MagicMock): + mock_send.return_value = { + "account_id": "acct-2", + "roles": [ + { + "id": "role-1", + "type": "workspace", + "role_key": "workspace.member", + "name": "Member", + } + ], + } + out = svc.RBACService.MemberRoles.get("tenant-1", "acct-1", "acct-2") + call = _call_args(mock_send) + assert call.method == "GET" + assert call.endpoint == "/rbac/members/rbac-roles" + assert call.params == {"account_id": "acct-2"} + assert out.account_id == "acct-2" + assert out.roles[0].role_key == "workspace.member" + + def test_replace(self, mock_send: MagicMock): + mock_send.return_value = {"account_id": "acct-2", "roles": []} + svc.RBACService.MemberRoles.replace( + "tenant-1", "acct-1", "acct-2", role_keys=["workspace.owner", "workspace.editor"] + ) + call = _call_args(mock_send) + assert call.method == "PUT" + assert call.endpoint == "/rbac/members/rbac-roles" + assert call.params == {"account_id": "acct-2"} + assert call.json == {"role_keys": ["workspace.owner", "workspace.editor"]} + + +class TestListOption: + def test_empty_produces_empty_params(self): + assert svc.ListOption().to_params() == {} + + def test_reverse_serialises_as_lowercase_bool(self): + assert svc.ListOption(reverse=False).to_params()["reverse"] == "false" + assert svc.ListOption(reverse=True).to_params()["reverse"] == "true" + + def test_extra_overrides_merge(self): + assert svc.ListOption(page_number=1).to_params({"resource_type": "app", "skip": None}) == { + "page_number": 1, + "resource_type": "app", + }