From 8912420bff95b8c5caf83d1c608cb08cc76944ab Mon Sep 17 00:00:00 2001 From: fatelei Date: Fri, 8 May 2026 14:00:28 +0800 Subject: [PATCH] refactor: refactor rbac api --- api/controllers/console/workspace/rbac.py | 100 +++++------------- api/services/enterprise/rbac_service.py | 75 +++++++++++++ .../console/workspace/test_rbac.py | 5 +- .../services/enterprise/test_rbac_service.py | 44 +++++--- 4 files changed, 130 insertions(+), 94 deletions(-) diff --git a/api/controllers/console/workspace/rbac.py b/api/controllers/console/workspace/rbac.py index e1c724c735..239d02d89a 100644 --- a/api/controllers/console/workspace/rbac.py +++ b/api/controllers/console/workspace/rbac.py @@ -276,12 +276,8 @@ class RBACAccessPolicyCopyApi(Resource): # --------------------------------------------------------------------------- -class _ReplaceRoleBindingsRequest(BaseModel): +class _ReplaceBindingsRequest(BaseModel): role_ids: list[str] = [] - - - -class _ReplaceMemberBindingsRequest(BaseModel): account_ids: list[str] = [] @@ -320,21 +316,6 @@ class RBACAppRoleBindingsApi(Resource): svc.RBACService.AppAccess.list_role_bindings(tenant_id, account_id, str(app_id), str(policy_id)) ) - @enterprise_only - @login_required - def put(self, app_id, policy_id): - tenant_id, account_id = _current_ids() - request = _payload(_ReplaceRoleBindingsRequest) - return _dump( - svc.RBACService.AppAccess.replace_role_bindings( - tenant_id, - account_id, - str(app_id), - str(policy_id), - svc.ReplaceRoleBindings(role_ids=list(request.role_ids)), - ) - ) - @console_ns.route("/workspaces/current/rbac/apps//access-policies//member-bindings") class RBACAppMemberBindingsApi(Resource): @@ -346,18 +327,21 @@ class RBACAppMemberBindingsApi(Resource): svc.RBACService.AppAccess.list_member_bindings(tenant_id, account_id, str(app_id), str(policy_id)) ) + +@console_ns.route("/workspaces/current/rbac/apps//access-policies//bindings") +class RBACAppBindingsApi(Resource): @enterprise_only @login_required def put(self, app_id, policy_id): tenant_id, account_id = _current_ids() - request = _payload(_ReplaceMemberBindingsRequest) + request = _payload(_ReplaceBindingsRequest) return _dump( - svc.RBACService.AppAccess.replace_member_bindings( + svc.RBACService.AppAccess.replace_bindings( tenant_id, account_id, str(app_id), str(policy_id), - svc.ReplaceMemberBindings(account_ids=list(request.account_ids)), + svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)), ) ) @@ -388,18 +372,21 @@ class RBACDatasetRoleBindingsApi(Resource): ) ) + +@console_ns.route("/workspaces/current/rbac/datasets//access-policies//bindings") +class RBACDatasetBindingsApi(Resource): @enterprise_only @login_required def put(self, dataset_id, policy_id): tenant_id, account_id = _current_ids() - request = _payload(_ReplaceRoleBindingsRequest) + request = _payload(_ReplaceBindingsRequest) return _dump( - svc.RBACService.DatasetAccess.replace_role_bindings( + svc.RBACService.DatasetAccess.replace_bindings( tenant_id, account_id, str(dataset_id), str(policy_id), - svc.ReplaceRoleBindings(role_ids=list(request.role_ids)), + svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)), ) ) @@ -418,21 +405,6 @@ class RBACDatasetMemberBindingsApi(Resource): ) ) - @enterprise_only - @login_required - def put(self, dataset_id, policy_id): - tenant_id, account_id = _current_ids() - request = _payload(_ReplaceMemberBindingsRequest) - return _dump( - svc.RBACService.DatasetAccess.replace_member_bindings( - tenant_id, - account_id, - str(dataset_id), - str(policy_id), - svc.ReplaceMemberBindings(account_ids=list(request.account_ids)), - ) - ) - # --------------------------------------------------------------------------- # Workspace-level access (Settings > Access Rules). @@ -459,17 +431,20 @@ class RBACWorkspaceAppRoleBindingsApi(Resource): svc.RBACService.WorkspaceAccess.list_app_role_bindings(tenant_id, account_id, str(policy_id)) ) + +@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies//bindings") +class RBACWorkspaceAppBindingsApi(Resource): @enterprise_only @login_required def put(self, policy_id): tenant_id, account_id = _current_ids() - request = _payload(_ReplaceRoleBindingsRequest) + request = _payload(_ReplaceBindingsRequest) return _dump( - svc.RBACService.WorkspaceAccess.replace_app_role_bindings( + svc.RBACService.WorkspaceAccess.replace_app_bindings( tenant_id, account_id, str(policy_id), - svc.ReplaceRoleBindings(role_ids=list(request.role_ids)), + svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)), ) ) @@ -484,20 +459,6 @@ class RBACWorkspaceAppMemberBindingsApi(Resource): svc.RBACService.WorkspaceAccess.list_app_member_bindings(tenant_id, account_id, str(policy_id)) ) - @enterprise_only - @login_required - def put(self, policy_id): - tenant_id, account_id = _current_ids() - request = _payload(_ReplaceMemberBindingsRequest) - return _dump( - svc.RBACService.WorkspaceAccess.replace_app_member_bindings( - tenant_id, - account_id, - str(policy_id), - svc.ReplaceMemberBindings(account_ids=list(request.account_ids)), - ) - ) - @console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policy") class RBACWorkspaceDatasetMatrixApi(Resource): @@ -519,17 +480,20 @@ class RBACWorkspaceDatasetRoleBindingsApi(Resource): svc.RBACService.WorkspaceAccess.list_dataset_role_bindings(tenant_id, account_id, str(policy_id)) ) + +@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies//bindings") +class RBACWorkspaceDatasetBindingsApi(Resource): @enterprise_only @login_required def put(self, policy_id): tenant_id, account_id = _current_ids() - request = _payload(_ReplaceRoleBindingsRequest) + request = _payload(_ReplaceBindingsRequest) return _dump( - svc.RBACService.WorkspaceAccess.replace_dataset_role_bindings( + svc.RBACService.WorkspaceAccess.replace_dataset_bindings( tenant_id, account_id, str(policy_id), - svc.ReplaceRoleBindings(role_ids=list(request.role_ids)), + svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)), ) ) @@ -544,20 +508,6 @@ class RBACWorkspaceDatasetMemberBindingsApi(Resource): svc.RBACService.WorkspaceAccess.list_dataset_member_bindings(tenant_id, account_id, str(policy_id)) ) - @enterprise_only - @login_required - def put(self, policy_id): - tenant_id, account_id = _current_ids() - request = _payload(_ReplaceMemberBindingsRequest) - return _dump( - svc.RBACService.WorkspaceAccess.replace_dataset_member_bindings( - tenant_id, - account_id, - str(policy_id), - svc.ReplaceMemberBindings(account_ids=list(request.account_ids)), - ) - ) - # --------------------------------------------------------------------------- # Member ↔ role bindings (Settings > Members > Assign roles). diff --git a/api/services/enterprise/rbac_service.py b/api/services/enterprise/rbac_service.py index 8d25bd9d04..d48fe37cab 100644 --- a/api/services/enterprise/rbac_service.py +++ b/api/services/enterprise/rbac_service.py @@ -193,6 +193,11 @@ class ReplaceMemberBindings(_RBACModel): account_ids: list[str] = Field(default_factory=list) +class ReplaceBindings(_RBACModel): + role_ids: list[str] = Field(default_factory=list) + account_ids: list[str] = Field(default_factory=list) + + class ListOption(_RBACModel): page_number: int | None = None results_per_page: int | None = None @@ -526,6 +531,24 @@ class RBACService: ) return MemberBindingsResponse.model_validate(data or {}) + @staticmethod + def replace_bindings( + tenant_id: str, + account_id: str | None, + app_id: str, + policy_id: str, + payload: ReplaceBindings, + ) -> AccessMatrixItem: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/apps/access-policy/bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"app_id": app_id, "policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return AccessMatrixItem.model_validate(data or {}) + # ------------------------------------------------------------------ # Per-dataset access (screenshot 1: Knowledge Base Access Config). # ------------------------------------------------------------------ @@ -609,6 +632,24 @@ class RBACService: ) return MemberBindingsResponse.model_validate(data or {}) + @staticmethod + def replace_bindings( + tenant_id: str, + account_id: str | None, + dataset_id: str, + policy_id: str, + payload: ReplaceBindings, + ) -> AccessMatrixItem: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/datasets/access-policy/bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"dataset_id": dataset_id, "policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return AccessMatrixItem.model_validate(data or {}) + # ------------------------------------------------------------------ # Workspace-level access (screenshot 2: Settings > Access Rules). # ------------------------------------------------------------------ @@ -709,6 +750,23 @@ class RBACService: ) return MemberBindingsResponse.model_validate(data or {}) + @staticmethod + def replace_app_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + payload: ReplaceBindings, + ) -> AccessMatrixItem: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/workspace/apps/access-policy/bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return AccessMatrixItem.model_validate(data or {}) + @staticmethod def list_dataset_role_bindings( tenant_id: str, @@ -773,6 +831,23 @@ class RBACService: ) return MemberBindingsResponse.model_validate(data or {}) + @staticmethod + def replace_dataset_bindings( + tenant_id: str, + account_id: str | None, + policy_id: str, + payload: ReplaceBindings, + ) -> AccessMatrixItem: + data = _inner_call( + "PUT", + f"{_INNER_PREFIX}/workspace/datasets/access-policy/bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"policy_id": policy_id}, + json=payload.model_dump(mode="json"), + ) + return AccessMatrixItem.model_validate(data or {}) + # ------------------------------------------------------------------ # Member ↔ role bindings (screenshot 3: Settings > Members > Assign roles). # ------------------------------------------------------------------ diff --git a/api/tests/unit_tests/controllers/console/workspace/test_rbac.py b/api/tests/unit_tests/controllers/console/workspace/test_rbac.py index fb786c130e..aa35ec7fc2 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_rbac.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_rbac.py @@ -110,9 +110,10 @@ class TestPydanticModels: with pytest.raises(ValidationError): rbac_mod._AccessPolicyCreateRequest.model_validate({"name": "bad", "resource_type": "unknown"}) - def test_replace_role_bindings_defaults_empty(self): - parsed = rbac_mod._ReplaceRoleBindingsRequest.model_validate({}) + def test_replace_bindings_defaults_empty(self): + parsed = rbac_mod._ReplaceBindingsRequest.model_validate({}) assert parsed.role_ids == [] + assert parsed.account_ids == [] def test_pagination_query_accepts_page_and_limit_aliases(self): parsed = rbac_mod._PaginationQuery.model_validate({"page": 3, "limit": 25, "reverse": True}) diff --git a/api/tests/unit_tests/services/enterprise/test_rbac_service.py b/api/tests/unit_tests/services/enterprise/test_rbac_service.py index 0dd3ff33b3..994942fb3d 100644 --- a/api/tests/unit_tests/services/enterprise/test_rbac_service.py +++ b/api/tests/unit_tests/services/enterprise/test_rbac_service.py @@ -191,27 +191,25 @@ class TestResourceAccess: assert call.params == {"app_id": "app-1"} assert out.app_id == "app-1" - def test_app_replace_role_bindings(self, mock_send: MagicMock): + def test_app_replace_bindings(self, mock_send: MagicMock): mock_send.return_value = {"data": []} - payload = svc.ReplaceRoleBindings(role_ids=["workspace.owner"]) - svc.RBACService.AppAccess.replace_role_bindings("tenant-1", "acct-1", "app-1", "policy-1", payload) + payload = svc.ReplaceBindings(role_ids=["workspace.owner"], account_ids=["acct-2"]) + svc.RBACService.AppAccess.replace_bindings("tenant-1", "acct-1", "app-1", "policy-1", payload) call = _call_args(mock_send) assert call.method == "PUT" - assert call.endpoint == "/rbac/apps/access-policy/role-bindings" + assert call.endpoint == "/rbac/apps/access-policy/bindings" assert call.params == {"app_id": "app-1", "policy_id": "policy-1"} - assert call.json == {"role_ids": ["workspace.owner"]} + assert call.json == {"role_ids": ["workspace.owner"], "account_ids": ["acct-2"]} - def test_dataset_replace_member_bindings(self, mock_send: MagicMock): + def test_dataset_replace_bindings(self, mock_send: MagicMock): mock_send.return_value = {"data": []} - payload = svc.ReplaceMemberBindings(account_ids=["acct-2"]) - svc.RBACService.DatasetAccess.replace_member_bindings( - "tenant-1", "acct-1", "ds-1", "policy-1", payload - ) + payload = svc.ReplaceBindings(role_ids=["workspace.editor"], account_ids=["acct-2"]) + svc.RBACService.DatasetAccess.replace_bindings("tenant-1", "acct-1", "ds-1", "policy-1", payload) call = _call_args(mock_send) assert call.method == "PUT" - assert call.endpoint == "/rbac/datasets/access-policy/member-bindings" + assert call.endpoint == "/rbac/datasets/access-policy/bindings" assert call.params == {"dataset_id": "ds-1", "policy_id": "policy-1"} - assert call.json == {"account_ids": ["acct-2"]} + assert call.json == {"role_ids": ["workspace.editor"], "account_ids": ["acct-2"]} class TestWorkspaceAccess: @@ -235,17 +233,29 @@ class TestWorkspaceAccess: assert call.endpoint == "/rbac/workspace/datasets/access-policy" assert call.params is None - def test_dataset_replace_role_bindings(self, mock_send: MagicMock): + def test_workspace_app_replace_bindings(self, mock_send: MagicMock): mock_send.return_value = {"data": []} - payload = svc.ReplaceRoleBindings(role_ids=["workspace.editor"]) - svc.RBACService.WorkspaceAccess.replace_dataset_role_bindings( + payload = svc.ReplaceBindings(role_ids=["workspace.editor"], account_ids=["acct-2"]) + svc.RBACService.WorkspaceAccess.replace_app_bindings( "tenant-1", "acct-1", "policy-1", payload ) call = _call_args(mock_send) assert call.method == "PUT" - assert call.endpoint == "/rbac/workspace/datasets/access-policy/role-bindings" + assert call.endpoint == "/rbac/workspace/apps/access-policy/bindings" assert call.params == {"policy_id": "policy-1"} - assert call.json == {"role_ids": ["workspace.editor"]} + assert call.json == {"role_ids": ["workspace.editor"], "account_ids": ["acct-2"]} + + def test_workspace_dataset_replace_bindings(self, mock_send: MagicMock): + mock_send.return_value = {"data": []} + payload = svc.ReplaceBindings(role_ids=["workspace.editor"], account_ids=["acct-2"]) + svc.RBACService.WorkspaceAccess.replace_dataset_bindings( + "tenant-1", "acct-1", "policy-1", payload + ) + call = _call_args(mock_send) + assert call.method == "PUT" + assert call.endpoint == "/rbac/workspace/datasets/access-policy/bindings" + assert call.params == {"policy_id": "policy-1"} + assert call.json == {"role_ids": ["workspace.editor"], "account_ids": ["acct-2"]} class TestMyPermissions: