diff --git a/api/controllers/console/snippets/snippet_workflow.py b/api/controllers/console/snippets/snippet_workflow.py index 5af885ab91b..444c8295eb0 100644 --- a/api/controllers/console/snippets/snippet_workflow.py +++ b/api/controllers/console/snippets/snippet_workflow.py @@ -157,7 +157,6 @@ class SnippetDraftWorkflowApi(Resource): @account_initialization_required @get_snippet @edit_permission_required - @rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_MANAGE, resource_required=False) def get(self, snippet: CustomizedSnippet): """Get draft workflow for snippet.""" snippet_service = _snippet_service() @@ -226,7 +225,6 @@ class SnippetDraftConfigApi(Resource): @account_initialization_required @get_snippet @edit_permission_required - @rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_MANAGE, resource_required=False) def get(self, snippet: CustomizedSnippet): """Get snippet draft workflow configuration limits.""" return { @@ -248,7 +246,6 @@ class SnippetPublishedWorkflowApi(Resource): @account_initialization_required @get_snippet @edit_permission_required - @rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_MANAGE, resource_required=False) def get(self, snippet: CustomizedSnippet): """Get published workflow for snippet.""" if not snippet.is_published: @@ -313,7 +310,6 @@ class SnippetDefaultBlockConfigsApi(Resource): @account_initialization_required @get_snippet @edit_permission_required - @rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_MANAGE, resource_required=False) def get(self, snippet: CustomizedSnippet): """Get default block configurations for snippet workflow.""" snippet_service = _snippet_service() @@ -336,7 +332,9 @@ class SnippetPublishedAllWorkflowApi(Resource): @account_initialization_required @get_snippet @edit_permission_required - @rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_MANAGE, resource_required=False) + @rbac_permission_required( + RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False + ) def get(self, snippet: CustomizedSnippet): """Get all published workflow versions for snippet.""" args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True)) @@ -503,9 +501,6 @@ class SnippetDraftNodeRunApi(Resource): @with_current_user @get_snippet @edit_permission_required - @rbac_permission_required( - RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False - ) def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str): """ Run a single node in snippet draft workflow. @@ -594,9 +589,6 @@ class SnippetDraftRunIterationNodeApi(Resource): @with_current_user @get_snippet @edit_permission_required - @rbac_permission_required( - RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False - ) def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str): """ Run a draft workflow iteration node for snippet. @@ -642,9 +634,6 @@ class SnippetDraftRunLoopNodeApi(Resource): @with_current_user @get_snippet @edit_permission_required - @rbac_permission_required( - RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False - ) def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str): """ Run a draft workflow loop node for snippet. @@ -688,9 +677,6 @@ class SnippetDraftWorkflowRunApi(Resource): @with_current_user @get_snippet @edit_permission_required - @rbac_permission_required( - RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False - ) def post(self, current_user: Account, snippet: CustomizedSnippet): """ Run draft workflow for snippet. @@ -729,9 +715,6 @@ class SnippetWorkflowTaskStopApi(Resource): @account_initialization_required @get_snippet @edit_permission_required - @rbac_permission_required( - RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False - ) def post(self, snippet: CustomizedSnippet, task_id: str): """ Stop a running snippet workflow task. diff --git a/api/controllers/console/snippets/snippet_workflow_draft_variable.py b/api/controllers/console/snippets/snippet_workflow_draft_variable.py index 4befd259666..d0f98ffa95b 100644 --- a/api/controllers/console/snippets/snippet_workflow_draft_variable.py +++ b/api/controllers/console/snippets/snippet_workflow_draft_variable.py @@ -105,7 +105,6 @@ class SnippetWorkflowVariableCollectionApi(Resource): ) @_snippet_draft_var_prerequisite @marshal_with(workflow_draft_variable_list_without_value_model) - @rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_MANAGE, resource_required=False) def get(self, current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -158,6 +157,9 @@ class SnippetNodeVariableCollectionApi(Resource): @console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)") @console_ns.response(204, "Node variables deleted successfully") @_snippet_draft_var_prerequisite + @rbac_permission_required( + RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False + ) def delete(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> Response: validate_node_id(node_id) srv = WorkflowDraftVariableService(db.session()) @@ -192,6 +194,9 @@ class SnippetVariableApi(Resource): @console_ns.response(404, "Variable not found") @_snippet_draft_var_prerequisite @marshal_with(workflow_draft_variable_model) + @rbac_permission_required( + RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False + ) def patch(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable: draft_var_srv = WorkflowDraftVariableService(session=db.session()) args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {}) @@ -239,6 +244,9 @@ class SnippetVariableApi(Resource): @console_ns.response(204, "Variable deleted successfully") @console_ns.response(404, "Variable not found") @_snippet_draft_var_prerequisite + @rbac_permission_required( + RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False + ) def delete(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response: draft_var_srv = WorkflowDraftVariableService(session=db.session()) variable = ensure_variable_access( @@ -261,6 +269,9 @@ class SnippetVariableResetApi(Resource): @console_ns.response(204, "Variable reset (no content)") @console_ns.response(404, "Variable not found") @_snippet_draft_var_prerequisite + @rbac_permission_required( + RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False + ) def put(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response | Any: draft_var_srv = WorkflowDraftVariableService(session=db.session()) snippet_service = _snippet_service() diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 38e7395ccf8..1af5b113f1d 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -4,12 +4,17 @@ from uuid import UUID from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from werkzeug.exceptions import Forbidden +from configs import dify_config from controllers.common.fields import SimpleResultResponse from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models +from controllers.common.wraps import enforce_rbac_access from controllers.console import console_ns from controllers.console.wraps import ( + RBACPermission, + RBACResourceScope, account_initialization_required, edit_permission_required, setup_required, @@ -18,9 +23,10 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from fields.base import ResponseModel -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import Account from models.enums import TagType +from models.model import Tag from services.tag_service import ( SaveTagPayload, TagBindingCreatePayload, @@ -91,6 +97,30 @@ register_schema_models( register_response_schema_models(console_ns, SimpleResultResponse) +def _enforce_snippet_tag_rbac_if_needed(tag_type: TagType | str | None) -> None: + if tag_type != TagType.SNIPPET: + return + if not dify_config.RBAC_ENABLED: + return + + current_user, current_tenant_id = current_account_with_tenant() + enforce_rbac_access( + tenant_id=current_tenant_id, + account_id=current_user.id, + resource_type=RBACResourceScope.WORKSPACE, + scene=RBACPermission.SNIPPETS_CREATE_AND_MODIFY, + resource_required=False, + ) + + +def _enforce_snippet_tag_rbac_by_tag_id(tag_id: str) -> None: + if not dify_config.RBAC_ENABLED: + return + + tag_type = db.session.scalar(select(Tag.type).where(Tag.id == tag_id).limit(1)) + _enforce_snippet_tag_rbac_if_needed(tag_type) + + @console_ns.route("/tags") class TagListApi(Resource): @setup_required @@ -122,6 +152,7 @@ class TagListApi(Resource): raise Forbidden() payload = TagBasePayload.model_validate(console_ns.payload or {}) + _enforce_snippet_tag_rbac_if_needed(payload.type) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type), db.session) response = TagResponse.model_validate( @@ -146,6 +177,7 @@ class TagUpdateDeleteApi(Resource): raise Forbidden() payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {}) + _enforce_snippet_tag_rbac_by_tag_id(tag_id_str) tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str, db.session) binding_count = TagService.get_tag_binding_count(tag_id_str, db.session) @@ -164,6 +196,7 @@ class TagUpdateDeleteApi(Resource): def delete(self, tag_id: UUID): tag_id_str = str(tag_id) + _enforce_snippet_tag_rbac_by_tag_id(tag_id_str) TagService.delete_tag(tag_id_str, db.session) return "", 204 @@ -184,6 +217,7 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]: _require_tag_binding_edit_permission(current_user) payload = TagBindingPayload.model_validate(console_ns.payload or {}) + _enforce_snippet_tag_rbac_if_needed(payload.type) TagService.save_tag_binding( TagBindingCreatePayload( tag_ids=payload.tag_ids, @@ -199,6 +233,7 @@ def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]: _require_tag_binding_edit_permission(current_user) payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + _enforce_snippet_tag_rbac_if_needed(payload.type) TagService.delete_tag_binding( TagBindingDeletePayload( tag_ids=payload.tag_ids, diff --git a/api/controllers/console/workspace/snippets.py b/api/controllers/console/workspace/snippets.py index e8f0b228c8b..e0987a3cae1 100644 --- a/api/controllers/console/workspace/snippets.py +++ b/api/controllers/console/workspace/snippets.py @@ -455,9 +455,6 @@ class CustomizedSnippetUseCountIncrementApi(Resource): @login_required @account_initialization_required @edit_permission_required - @rbac_permission_required( - RBACResourceScope.WORKSPACE, RBACPermission.SNIPPETS_CREATE_AND_MODIFY, resource_required=False - ) @with_current_tenant_id def post(self, current_tenant_id: str, snippet_id: str): """Increment snippet use count when it is inserted into a workflow.""" diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 84a70835437..8f5bb176b8e 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -155,6 +155,36 @@ class TestTagListApi: assert result["name"] == "test-tag" assert result["binding_count"] == "0" + def test_post_snippet_tag_checks_snippet_rbac_when_enabled(self, app: Flask, admin_user, tag, payload_patch): + api = TagListApi() + method = unwrap(api.post) + + payload = {"name": "snippet-tag", "type": "snippet"} + + with app.test_request_context("/", json=payload): + with ( + payload_patch(payload), + patch("controllers.console.tag.tags.dify_config.RBAC_ENABLED", True), + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(SimpleNamespace(id="user-1"), "tenant-1"), + ), + patch("controllers.console.tag.tags.enforce_rbac_access") as enforce_mock, + patch( + "controllers.console.tag.tags.TagService.save_tags", + return_value=tag, + ), + ): + method(api, admin_user) + + enforce_mock.assert_called_once_with( + tenant_id="tenant-1", + account_id="user-1", + resource_type=module.RBACResourceScope.WORKSPACE, + scene=module.RBACPermission.SNIPPETS_CREATE_AND_MODIFY, + resource_required=False, + ) + def test_post_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagListApi() method = unwrap(api.post)