From c657378d06cd646ccdd4dcb09ab047abe64fa971 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 28 Oct 2024 15:54:34 +0800 Subject: [PATCH] feat: support plugin permission management --- .../console/datasets/datasets_document.py | 129 +++++++------- api/controllers/console/workspace/__init__.py | 56 ++++++ api/controllers/console/workspace/plugin.py | 160 +++++++++--------- ...c4f75af5e_add_tenant_plugin_permisisons.py | 37 ++++ api/models/account.py | 26 +++ .../plugin/plugin_permission_service.py | 34 ++++ 6 files changed, 289 insertions(+), 153 deletions(-) create mode 100644 api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py create mode 100644 api/services/plugin/plugin_permission_service.py diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 5f681c238f..31b4f7b741 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -5,8 +5,7 @@ from datetime import datetime, timezone from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal, marshal_with, reqparse -from sqlalchemy import asc, desc, select -from sqlalchemy.orm import Session +from sqlalchemy import asc, desc from werkzeug.exceptions import Forbidden, NotFound import services @@ -105,8 +104,7 @@ class GetProcessRuleApi(Resource): rules = DocumentService.DEFAULT_RULES["rules"] if document_id: # get the latest process rule - with Session(db.engine) as session: - document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none() + document = Document.query.get_or_404(document_id) dataset = DatasetService.get_dataset(document.dataset_id) @@ -169,77 +167,66 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - with Session(db.engine) as session: - query = session.query(Document).filter_by( - dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + + if search: + search = f"%{search}%" + query = query.filter(Document.name.like(search)) + + if sort.startswith("-"): + sort_logic = desc + sort = sort[1:] + else: + sort_logic = asc + + if sort == "hit_count": + sub_query = ( + db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + .group_by(DocumentSegment.document_id) + .subquery() ) - if search: - search = f"%{search}%" - query = query.filter(Document.name.like(search)) + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + elif sort == "created_at": + query = query.order_by( + sort_logic(Document.created_at), + sort_logic(Document.position), + ) + else: + query = query.order_by( + desc(Document.created_at), + desc(Document.position), + ) - if sort.startswith("-"): - sort_logic = desc - sort = sort[1:] - else: - sort_logic = asc + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + documents = paginated_documents.items + if fetch: + for document in documents: + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() + document.completed_segments = completed_segments + document.total_segments = total_segments + data = marshal(documents, document_with_segments_fields) + else: + data = marshal(documents, document_fields) + response = { + "data": data, + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, + } - if sort == "hit_count": - sub_query = ( - db.select( - DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count") - ) - .group_by(DocumentSegment.document_id) - .subquery() - ) - - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( - sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), - sort_logic(Document.position), - ) - elif sort == "created_at": - query = query.order_by( - sort_logic(Document.created_at), - sort_logic(Document.position), - ) - else: - query = query.order_by( - desc(Document.created_at), - desc(Document.position), - ) - - paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) - documents = paginated_documents.items - if fetch: - for document in documents: - completed_segments = ( - session.query(DocumentSegment) - .filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ) - .count() - ) - total_segments = ( - session.query(DocumentSegment) - .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") - .count() - ) - document.completed_segments = completed_segments - document.total_segments = total_segments - data = marshal(documents, document_with_segments_fields) - else: - data = marshal(documents, document_fields) - response = { - "data": data, - "has_more": len(documents) == limit, - "limit": limit, - "total": paginated_documents.total, - "page": page, - } - - return response + return response documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index e69de29bb2..072e904caf 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -0,0 +1,56 @@ +from functools import wraps + +from flask_login import current_user +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from extensions.ext_database import db +from models.account import TenantPluginPermission + + +def plugin_permission_required( + install_required: bool = False, + debug_required: bool = False, +): + def interceptor(view): + @wraps(view) + def decorated(*args, **kwargs): + user = current_user + tenant_id = user.current_tenant_id + + with Session(db.engine) as session: + permission = ( + session.query(TenantPluginPermission) + .filter( + TenantPluginPermission.tenant_id == tenant_id, + ) + .first() + ) + + if not permission: + # no permission set, allow access for everyone + return view(*args, **kwargs) + + if install_required: + if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY: + raise Forbidden() + if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE: + pass + + if debug_required: + if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY: + raise Forbidden() + if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE: + pass + + return view(*args, **kwargs) + + return decorated + + return interceptor diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index b3333f67e0..2daf3ae173 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -8,9 +8,12 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api from controllers.console.setup import setup_required +from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required +from models.account import TenantPluginPermission +from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_service import PluginService @@ -18,12 +21,9 @@ class PluginDebuggingKeyApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def get(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id return { "key": PluginService.get_debugging_key(tenant_id), @@ -37,8 +37,7 @@ class PluginListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id plugins = PluginService.list(tenant_id) return jsonable_encoder({"plugins": plugins}) @@ -57,32 +56,13 @@ class PluginIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) -class PluginUploadPkgApi(Resource): - @setup_required - @login_required - @account_initialization_required - def post(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id - file = request.files["pkg"] - content = file.read() - - return jsonable_encoder(PluginService.upload_pkg(tenant_id, content)) - - class PluginUploadFromPkgApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(install_required=True) def post(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id file = request.files["pkg"] @@ -100,12 +80,9 @@ class PluginUploadFromGithubApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(install_required=True) def post(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("repo", type=str, required=True, location="json") @@ -124,12 +101,9 @@ class PluginInstallFromPkgApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(install_required=True) def post(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") @@ -149,12 +123,9 @@ class PluginInstallFromGithubApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(install_required=True) def post(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("repo", type=str, required=True, location="json") @@ -178,12 +149,9 @@ class PluginInstallFromMarketplaceApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(install_required=True) def post(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json") @@ -203,15 +171,14 @@ class PluginFetchManifestApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def get(self): - user = current_user + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") args = parser.parse_args() - tenant_id = user.current_tenant_id - return jsonable_encoder( {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args["plugin_unique_identifier"]).model_dump()} ) @@ -221,12 +188,9 @@ class PluginFetchInstallTasksApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def get(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -242,12 +206,9 @@ class PluginFetchInstallTaskApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def get(self, task_id: str): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) @@ -256,12 +217,9 @@ class PluginDeleteInstallTaskApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def post(self, task_id: str): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id return {"success": PluginService.delete_install_task(tenant_id, task_id)} @@ -270,12 +228,9 @@ class PluginDeleteInstallTaskItemApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def post(self, task_id: str, identifier: str): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} @@ -284,12 +239,9 @@ class PluginUpgradeFromMarketplaceApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def post(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") @@ -307,12 +259,9 @@ class PluginUpgradeFromGithubApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def post(self): - user = current_user - if not user.is_admin_or_owner: - raise Forbidden() - - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json") @@ -338,18 +287,62 @@ class PluginUninstallApi(Resource): @setup_required @login_required @account_initialization_required + @plugin_permission_required(debug_required=True) def post(self): req = reqparse.RequestParser() req.add_argument("plugin_installation_id", type=str, required=True, location="json") args = req.parse_args() + tenant_id = current_user.current_tenant_id + + return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} + + +class PluginChangePermissionApi(Resource): + @setup_required + @login_required + @account_initialization_required + @plugin_permission_required(debug_required=True) + def post(self): user = current_user if not user.is_admin_or_owner: raise Forbidden() + req = reqparse.RequestParser() + req.add_argument("install_permission", type=str, required=True, location="json") + req.add_argument("debug_permission", type=str, required=True, location="json") + args = req.parse_args() + + install_permission = TenantPluginPermission.InstallPermission(args["install_permission"]) + debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"]) + tenant_id = user.current_tenant_id - return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])} + return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)} + + +class PluginFetchPermissionApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + tenant_id = current_user.current_tenant_id + + permission = PluginPermissionService.get_permission(tenant_id) + if not permission: + return jsonable_encoder( + { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + } + ) + + return jsonable_encoder( + { + "install_permission": permission.install_permission, + "debug_permission": permission.debug_permission, + } + ) api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") @@ -368,3 +361,6 @@ api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks//delete") api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks//delete/") api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") + +api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") +api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") diff --git a/api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py b/api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py new file mode 100644 index 0000000000..51a0b1b211 --- /dev/null +++ b/api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py @@ -0,0 +1,37 @@ +"""add_tenant_plugin_permisisons + +Revision ID: 08ec4f75af5e +Revises: ddcc8bbef391 +Create Date: 2024-10-28 07:20:39.711124 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '08ec4f75af5e' +down_revision = 'ddcc8bbef391' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('account_plugin_permissions', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('install_permission', sa.String(length=16), server_default='everyone', nullable=False), + sa.Column('debug_permission', sa.String(length=16), server_default='noone', nullable=False), + sa.PrimaryKeyConstraint('id', name='account_plugin_permission_pkey'), + sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin') + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('account_plugin_permissions') + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index ae87e22649..99464865dd 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -2,6 +2,7 @@ import enum import json from flask_login import UserMixin +from sqlalchemy.orm import Mapped, mapped_column from extensions.ext_database import db from models.base import Base @@ -260,3 +261,28 @@ class InvitationCode(db.Model): used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + +class TenantPluginPermission(Base): + class InstallPermission(str, enum.Enum): + EVERYONE = "everyone" + ADMINS = "admins" + NOBODY = "noone" + + class DebugPermission(str, enum.Enum): + EVERYONE = "everyone" + ADMINS = "admins" + NOBODY = "noone" + + __tablename__ = "account_plugin_permissions" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"), + db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + install_permission: Mapped[InstallPermission] = mapped_column( + db.String(16), nullable=False, server_default="everyone" + ) + debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone") diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py new file mode 100644 index 0000000000..275e496037 --- /dev/null +++ b/api/services/plugin/plugin_permission_service.py @@ -0,0 +1,34 @@ +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.account import TenantPluginPermission + + +class PluginPermissionService: + @staticmethod + def get_permission(tenant_id: str) -> TenantPluginPermission | None: + with Session(db.engine) as session: + return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first() + + @staticmethod + def change_permission( + tenant_id: str, + install_permission: TenantPluginPermission.InstallPermission, + debug_permission: TenantPluginPermission.DebugPermission, + ): + with Session(db.engine) as session: + permission = ( + session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first() + ) + if not permission: + permission = TenantPluginPermission( + tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission + ) + + session.add(permission) + else: + permission.install_permission = install_permission + permission.debug_permission = debug_permission + + session.commit() + return True