diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index 8a6a921a4e..f5ae5db558 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -23,6 +23,11 @@ class EnterpriseFeatureConfig(BaseSettings): ge=1, description="Maximum timeout in seconds for enterprise requests", default=5 ) + RBAC_ENABLED: bool = Field( + description="Enable enterprise RBAC APIs. When disabled, compatibility responses fall back to legacy roles.", + default=False, + ) + class EnterpriseTelemetryConfig(BaseSettings): """ diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index e3bf4c95b8..8e21506d66 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -30,6 +30,7 @@ from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole from services.account_service import AccountService, RegisterService, TenantService +from services.enterprise import rbac_service as enterprise_rbac_service from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService @@ -72,6 +73,19 @@ register_enum_models(console_ns, TenantAccountRole) register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) +def _serialize_member_roles(current_role: str | None, member_role_ids: list[str]) -> list[str]: + if member_role_ids: + return member_role_ids + if current_role: + return [current_role] + return [] + + +def _normalize_enum_value(value: object) -> str: + normalized = getattr(value, "value", value) + return str(normalized) if normalized is not None else "" + + @console_ns.route("/workspaces/current/members") class MemberListApi(Resource): """List all members of current tenant.""" @@ -85,7 +99,36 @@ class MemberListApi(Resource): if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) - member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + if dify_config.RBAC_ENABLED: + member_ids = [member.id for member in members] + member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get( + str(current_user.current_tenant.id), + current_user.id, + member_ids, + ) + roles_map = {item.account_id: [role.id for role in item.roles] for item in member_roles} + else: + roles_map = {} + + serialized_members = [] + for member in members: + current_role = _normalize_enum_value(member.current_role) + serialized_members.append( + { + "id": member.id, + "name": member.name, + "email": member.email, + "avatar": member.avatar, + "last_login_at": member.last_login_at, + "last_active_at": member.last_active_at, + "created_at": member.created_at, + "role": current_role, + "roles": _serialize_member_roles(current_role, roles_map.get(member.id, [])), + "status": _normalize_enum_value(member.status), + } + ) + + member_models = TypeAdapter(list[AccountWithRole]).validate_python(serialized_members) response = AccountWithRoleList(accounts=member_models) return response.model_dump(mode="json"), 200 diff --git a/api/controllers/console/workspace/rbac.py b/api/controllers/console/workspace/rbac.py index 939fa25d0e..0a93a456de 100644 --- a/api/controllers/console/workspace/rbac.py +++ b/api/controllers/console/workspace/rbac.py @@ -15,6 +15,89 @@ from libs.login import current_account_with_tenant, login_required from services.enterprise import rbac_service as svc +_LEGACY_WORKSPACE_PERMISSION_KEYS: list[str] = [ + "inviteMembers", + "removeMembers", + "assignRoles", + "workspaceSettings", + "manageBilling", + "transferOwnership", +] + +_LEGACY_APP_PERMISSION_KEYS: list[str] = [ + "createApps", + "editApps", + "useApps", +] + +_LEGACY_DATASET_PERMISSION_KEYS: list[str] = [ + "createDatasets", + "editDatasets", + "manageDatasets", +] + +_LEGACY_ENTERPRISE_PERMISSION_KEYS: list[str] = [ + "workspace.member.manage", + "workspace.settings.manage", + "workspace.billing.manage", + "workspace.owner.transfer", + "app.acl.edit", + "app.acl.test_and_run", + "dataset.acl.edit", +] + +_LEGACY_ROLE_PERMISSION_KEYS: dict[str, list[str]] = { + # These legacy role groups predate the RBAC refactor. The mapping keeps the + # old workspace roles readable through the new RBAC endpoint by translating + # each role into the closest enterprise permission keys that already exist + # in the catalog and tests. + "owner": [ + *_LEGACY_WORKSPACE_PERMISSION_KEYS, + *_LEGACY_APP_PERMISSION_KEYS, + *_LEGACY_DATASET_PERMISSION_KEYS, + *_LEGACY_ENTERPRISE_PERMISSION_KEYS, + ], + "admin": [ + "inviteMembers", + "removeMembers", + "assignRoles", + "workspaceSettings", + "manageBilling", + "workspace.member.manage", + "workspace.settings.manage", + "workspace.billing.manage", + "app.acl.edit", + "app.acl.test_and_run", + "dataset.acl.edit", + "createApps", + "editApps", + "useApps", + "createDatasets", + "editDatasets", + "manageDatasets", + ], + "editor": [ + "createApps", + "editApps", + "useApps", + "createDatasets", + "editDatasets", + "workspace.member.manage", + "app.acl.edit", + "app.acl.test_and_run", + "dataset.acl.edit", + ], + "normal": [ + "useApps", + "app.acl.test_and_run", + ], + "dataset_operator": [ + "manageDatasets", + "dataset.acl.edit", + ], +} + + def _current_ids() -> tuple[str, str]: """Return ``(tenant_id, account_id)`` for the authenticated user, or raise a 404 when no tenant is associated with the session. @@ -60,6 +143,49 @@ def _pagination_options() -> svc.ListOption: return _PaginationQuery.model_validate(request.args.to_dict(flat=True)).to_inner_options() +def _legacy_workspace_roles(options: svc.ListOption | None = None) -> svc.Paginated[svc.RBACRole]: + """Return the built-in legacy workspace roles in the RBAC list shape. + + This keeps the new `/rbac/roles` endpoint compatible with the original + Dify role model when enterprise RBAC is disabled. + """ + + legacy_roles = [ + svc.RBACRole( + id=role_name, + tenant_id="", + type=svc.RBACRoleType.WORKSPACE.value, + category="global_system_default", + name=role_name, + description="", + is_builtin=True, + permission_keys=list(_LEGACY_ROLE_PERMISSION_KEYS[role_name]), + ) + for role_name in ("owner", "admin", "editor", "normal", "dataset_operator") + ] + + page_number = options.page_number if options and options.page_number is not None else 1 + results_per_page = options.results_per_page if options and options.results_per_page is not None else len(legacy_roles) + reverse = options.reverse if options and options.reverse is not None else False + + ordered_roles = list(reversed(legacy_roles)) if reverse else legacy_roles + start = max(page_number - 1, 0) * results_per_page + end = start + results_per_page + paged_roles = ordered_roles[start:end] + total_count = len(legacy_roles) + total_pages = (total_count + results_per_page - 1) // results_per_page if results_per_page > 0 else 0 + + return svc.Paginated[svc.RBACRole]( + data=paged_roles, + pagination=svc.Pagination( + total_count=total_count, + per_page=results_per_page, + current_page=page_number, + total_pages=total_pages, + ), + ) + + # --------------------------------------------------------------------------- # Permission catalogs. # --------------------------------------------------------------------------- @@ -115,6 +241,8 @@ class RBACRolesApi(Resource): def get(self): tenant_id, account_id = _current_ids() options = _pagination_options() + if not dify_config.RBAC_ENABLED: + return _dump(_legacy_workspace_roles(options)) return _dump(svc.RBACService.Roles.list(tenant_id, account_id, options=options)) @login_required diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 67b320beaa..ce8f27a28e 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields -from pydantic import computed_field, field_validator +from pydantic import Field, computed_field, field_validator from fields.base import ResponseModel from graphon.file import helpers as file_helpers @@ -70,6 +70,7 @@ class AccountWithRole(_AccountAvatar): last_active_at: int | None = None created_at: int | None = None role: str + roles: list[str] = Field(default_factory=list) status: str @field_validator("last_login_at", "last_active_at", "created_at", mode="before") diff --git a/api/services/enterprise/rbac_service.py b/api/services/enterprise/rbac_service.py index ae9a89b379..aa317b6a87 100644 --- a/api/services/enterprise/rbac_service.py +++ b/api/services/enterprise/rbac_service.py @@ -149,6 +149,10 @@ class MemberRolesResponse(_RBACModel): roles: list[RBACRole] = Field(default_factory=list) +class MemberRolesBatchResponse(_RBACModel): + data: list[MemberRolesResponse] = Field(default_factory=list) + + class ResourcePermissionKeys(_RBACModel): resource_id: str permission_keys: list[str] = Field(default_factory=list) @@ -908,6 +912,25 @@ class RBACService: ) return MemberRolesResponse.model_validate(data or {}) + @staticmethod + def batch_get( + tenant_id: str, + account_id: str | None, + member_account_ids: list[str], + ) -> list[MemberRolesResponse]: + data = _inner_call( + "POST", + f"{_INNER_PREFIX}/members/rbac-roles/batch", + tenant_id=tenant_id, + account_id=account_id, + json={"account_ids": member_account_ids}, + ) + if isinstance(data, list): + items = data + else: + items = (data or {}).get("data") or [] + return [MemberRolesResponse.model_validate(item) for item in items] + @staticmethod def replace( tenant_id: str, diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py index 718b57ba6b..6c8fea1e50 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_members.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -45,8 +46,8 @@ class TestMemberListApi: member.name = "Member" member.email = "member@test.com" member.avatar = "avatar.png" - member.role = "admin" - member.status = "active" + member.current_role = SimpleNamespace(value="admin") + member.status = SimpleNamespace(value="active") members = [member] with ( @@ -58,6 +59,47 @@ class TestMemberListApi: assert status == 200 assert len(result["accounts"]) == 1 + assert result["accounts"][0]["role"] == "admin" + assert result["accounts"][0]["roles"] == ["admin"] + + def test_get_with_rbac_enabled_fetches_roles_in_batch(self, app): + api = MemberListApi() + method = unwrap(api.get) + + tenant = MagicMock(id="tenant-1") + user = MagicMock(id="acct-1", current_tenant=tenant) + member = SimpleNamespace( + id="m1", + name="Member", + email="member@test.com", + avatar=None, + last_login_at=1, + last_active_at=2, + created_at=3, + current_role=SimpleNamespace(value="editor"), + status=SimpleNamespace(value="active"), + ) + role_item = SimpleNamespace( + account_id="m1", + roles=[SimpleNamespace(id="workspace.owner"), SimpleNamespace(id="workspace.editor")], + ) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "tenant-1")), + patch("controllers.console.workspace.members.dify_config.RBAC_ENABLED", True), + patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=[member]), + patch( + "controllers.console.workspace.members.enterprise_rbac_service.RBACService.MemberRoles.batch_get", + return_value=[role_item], + ) as mock_batch_get, + ): + result, status = method(api) + + assert status == 200 + assert result["accounts"][0]["role"] == "editor" + assert result["accounts"][0]["roles"] == ["workspace.owner", "workspace.editor"] + mock_batch_get.assert_called_once_with("tenant-1", "acct-1", ["m1"]) def test_get_no_tenant(self, app): api = MemberListApi() diff --git a/api/tests/unit_tests/controllers/console/workspace/test_rbac.py b/api/tests/unit_tests/controllers/console/workspace/test_rbac.py index 9cba20ba2b..561b533e25 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_rbac.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_rbac.py @@ -37,6 +37,7 @@ def app(): def _enabled(enabled: bool): return patch("controllers.console.workspace.rbac.dify_config.ENTERPRISE_ENABLED", enabled) + class TestCurrentIds: def test_rejects_missing_tenant(self): with patch("controllers.console.workspace.rbac.current_account_with_tenant") as mock_user: @@ -117,10 +118,87 @@ class TestPydanticModels: class TestPaginationMapping: + def test_roles_get_returns_legacy_compatible_roles_when_rbac_disabled(self, app): + with ( + app.test_request_context("/workspaces/current/rbac/roles?page=1&limit=2"), + patch("controllers.console.workspace.rbac.dify_config.RBAC_ENABLED", False), + patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")), + patch("controllers.console.workspace.rbac.svc.RBACService.Roles.list") as mock_list, + ): + response = inspect.unwrap(rbac_mod.RBACRolesApi.get)(rbac_mod.RBACRolesApi()) + + assert response["data"] == [ + { + "id": "owner", + "tenant_id": "", + "type": "workspace", + "category": "global_system_default", + "name": "owner", + "description": "", + "is_builtin": True, + "permission_keys": [ + "inviteMembers", + "removeMembers", + "assignRoles", + "workspaceSettings", + "manageBilling", + "transferOwnership", + "createApps", + "editApps", + "useApps", + "createDatasets", + "editDatasets", + "manageDatasets", + "workspace.member.manage", + "workspace.settings.manage", + "workspace.billing.manage", + "workspace.owner.transfer", + "app.acl.edit", + "app.acl.test_and_run", + "dataset.acl.edit", + ], + }, + { + "id": "admin", + "tenant_id": "", + "type": "workspace", + "category": "global_system_default", + "name": "admin", + "description": "", + "is_builtin": True, + "permission_keys": [ + "inviteMembers", + "removeMembers", + "assignRoles", + "workspaceSettings", + "manageBilling", + "workspace.member.manage", + "workspace.settings.manage", + "workspace.billing.manage", + "app.acl.edit", + "app.acl.test_and_run", + "dataset.acl.edit", + "createApps", + "editApps", + "useApps", + "createDatasets", + "editDatasets", + "manageDatasets", + ], + }, + ] + assert response["pagination"] == { + "total_count": 5, + "per_page": 2, + "current_page": 1, + "total_pages": 3, + } + mock_list.assert_not_called() + def test_roles_get_forwards_outer_pagination_params(self, app): with ( app.test_request_context("/workspaces/current/rbac/roles?page=2&limit=50&reverse=true"), - _enabled(True), + patch("controllers.console.workspace.rbac.dify_config.RBAC_ENABLED", True), patch("controllers.console.workspace.rbac._current_ids", return_value=("tenant-1", "acct-1")), patch("controllers.console.workspace.rbac.svc.RBACService.Roles.list") as mock_list, patch("controllers.console.workspace.rbac._dump", return_value={}), diff --git a/api/tests/unit_tests/services/enterprise/test_rbac_service.py b/api/tests/unit_tests/services/enterprise/test_rbac_service.py index 90839a32d0..079bc7d73d 100644 --- a/api/tests/unit_tests/services/enterprise/test_rbac_service.py +++ b/api/tests/unit_tests/services/enterprise/test_rbac_service.py @@ -280,8 +280,8 @@ class TestWorkspaceAccess: out = svc.RBACService.WorkspaceAccess.app_matrix("tenant-1") - assert out.items[0].roles == [] - assert out.items[0].accounts == [] + assert out.items[0].role_ids == [] + assert out.items[0].account_ids == [] def test_workspace_app_replace_bindings(self, mock_send: MagicMock): mock_send.return_value = {"data": []} @@ -372,6 +372,32 @@ class TestMemberRoles: assert call.params == {"account_id": "acct-2"} assert call.json == {"role_ids": ["workspace.owner", "workspace.editor"]} + def test_batch_get(self, mock_send: MagicMock): + mock_send.return_value = { + "data": [ + { + "account_id": "acct-2", + "roles": [ + {"id": "role-1", "type": "workspace", "name": "Admin"}, + {"id": "role-2", "type": "workspace", "name": "Editor"}, + ], + }, + { + "account_id": "acct-3", + "roles": [], + }, + ] + } + + out = svc.RBACService.MemberRoles.batch_get("tenant-1", "acct-1", ["acct-2", "acct-3"]) + + call = _call_args(mock_send) + assert call.method == "POST" + assert call.endpoint == "/rbac/members/rbac-roles/batch" + assert call.json == {"account_ids": ["acct-2", "acct-3"]} + assert out[0].account_id == "acct-2" + assert len(out[0].roles) == 2 + class TestListOption: def test_empty_produces_empty_params(self):