diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 4f04af7932..bd5862cbd0 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -104,14 +104,11 @@ class BaseApiKeyResource(Resource): resource_model: type | None = None resource_id_field: str | None = None - def delete(self, resource_id, api_key_id): + def delete(self, resource_id: str, api_key_id: str): assert self.resource_id_field is not None, "resource_id_field must be set" - resource_id = str(resource_id) - api_key_id = str(api_key_id) current_user, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) - # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index defe82b8ae..a487512961 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -3,7 +3,7 @@ import uuid from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session -from werkzeug.exceptions import BadRequest, Forbidden, abort +from werkzeug.exceptions import BadRequest, abort from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model @@ -12,6 +12,7 @@ from controllers.console.wraps import ( cloud_edition_billing_resource_check, edit_permission_required, enterprise_license_required, + is_admin_or_owner_required, setup_required, ) from core.ops.ops_trace_manager import OpsTraceManager @@ -485,15 +486,11 @@ class AppApiStatus(Resource): @api.response(403, "Insufficient permissions") @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required @get_app_model @marshal_with(app_detail_fields) def post(self, app_model): - # The role of the current user in the ta table must be admin or owner - current_user, _ = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() - parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json") args = parser.parse_args() diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 72ce8a7ddf..91e2cfd60e 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -3,11 +3,10 @@ from typing import cast from flask import request from flask_restx import Resource, fields -from werkzeug.exceptions import Forbidden from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager @@ -48,15 +47,12 @@ class ModelConfigResource(Resource): @api.response(404, "App not found") @setup_required @login_required + @edit_permission_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): """Modify app model config""" current_user, current_tenant_id = current_account_with_tenant() - - if not current_user.has_edit_permission: - raise Forbidden() - # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_tenant_id, diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index c4d640bf0e..b8edbf77c7 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,10 +1,15 @@ from flask_restx import Resource, fields, marshal_with, reqparse -from werkzeug.exceptions import Forbidden, NotFound +from werkzeug.exceptions import NotFound from constants.languages import supported_language from controllers.console import api, console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + is_admin_or_owner_required, + setup_required, +) from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.datetime_utils import naive_utc_now @@ -76,17 +81,13 @@ class AppSite(Resource): @api.response(404, "App not found") @setup_required @login_required + @edit_permission_required @account_initialization_required @get_app_model @marshal_with(app_site_fields) def post(self, app_model): args = parse_app_site_args() current_user, _ = current_account_with_tenant() - - # The role of the current user in the ta table must be editor, admin, or owner - if not current_user.has_edit_permission: - raise Forbidden() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise NotFound @@ -130,16 +131,12 @@ class AppSiteAccessTokenReset(Resource): @api.response(404, "App or site not found") @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required @get_app_model @marshal_with(app_site_fields) def post(self, app_model): - # The role of the current user in the ta table must be admin or owner current_user, _ = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() - site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 31077e371b..2f6808f11d 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -983,8 +983,9 @@ class DraftWorkflowTriggerRunApi(Resource): Poll for trigger events and execute full workflow when event arrives """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("node_id", type=str, required=True, location="json", nullable=False) + parser = reqparse.RequestParser().add_argument( + "node_id", type=str, required=True, location="json", nullable=False + ) args = parser.parse_args() node_id = args["node_id"] workflow_service = WorkflowService() @@ -1136,8 +1137,9 @@ class DraftWorkflowTriggerRunAllApi(Resource): """ current_user, _ = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False) + parser = reqparse.RequestParser().add_argument( + "node_ids", type=list, required=True, location="json", nullable=False + ) args = parser.parse_args() node_ids = args["node_ids"] workflow_service = WorkflowService() diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 0722eb40d2..ca97d8520c 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -1,17 +1,18 @@ import logging -from typing import NoReturn +from collections.abc import Callable +from functools import wraps +from typing import NoReturn, ParamSpec, TypeVar from flask import Response from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden from controllers.console import api, console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, ) from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError from core.file import helpers as file_helpers from core.variables.segment_group import SegmentGroup @@ -21,8 +22,8 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type -from libs.login import current_user, login_required -from models import Account, App, AppMode +from libs.login import login_required +from models import App, AppMode from models.workflow import WorkflowDraftVariable from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_service import WorkflowService @@ -140,8 +141,11 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = { "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), } +P = ParamSpec("P") +R = TypeVar("R") -def _api_prerequisite(f): + +def _api_prerequisite(f: Callable[P, R]): """Common prerequisites for all draft workflow variable APIs. It ensures the following conditions are satisfied: @@ -155,11 +159,10 @@ def _api_prerequisite(f): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def wrapper(*args, **kwargs): - assert isinstance(current_user, Account) - if not current_user.has_edit_permission: - raise Forbidden() + @wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs): return f(*args, **kwargs) return wrapper @@ -167,6 +170,7 @@ def _api_prerequisite(f): @console_ns.route("/apps//workflows/draft/variables") class WorkflowVariableCollectionApi(Resource): + @api.expect(_create_pagination_parser()) @api.doc("get_workflow_variables") @api.doc(description="Get draft workflow variables") @api.doc(params={"app_id": "Application ID"}) diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index fd64261525..785813c5f0 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -3,12 +3,12 @@ import logging from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, NotFound +from werkzeug.exceptions import NotFound from configs import dify_config from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from extensions.ext_database import db from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields from libs.login import current_user, login_required @@ -29,8 +29,7 @@ class WebhookTriggerApi(Resource): @marshal_with(webhook_trigger_fields) def get(self, app_model: App): """Get webhook trigger for a node""" - parser = reqparse.RequestParser() - parser.add_argument("node_id", type=str, required=True, help="Node ID is required") + parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required") args = parser.parse_args() node_id = str(args["node_id"]) @@ -95,19 +94,19 @@ class AppTriggerEnableApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_app_model(mode=AppMode.WORKFLOW) @marshal_with(trigger_fields) def post(self, app_model: App): """Update app trigger (enable/disable)""" - parser = reqparse.RequestParser() - parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json") - parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("trigger_id", type=str, required=True, nullable=False, location="json") + .add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json") + ) args = parser.parse_args() - assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None - if not current_user.has_edit_permission: - raise Forbidden() trigger_id = args["trigger_id"] diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index a06435267b..9d7fcef183 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,8 +1,8 @@ from flask_restx import Resource, reqparse -from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.auth.error import ApiKeyAuthFailedError +from controllers.console.wraps import is_admin_or_owner_required from libs.login import current_account_with_tenant, login_required from services.auth.api_key_auth_service import ApiKeyAuthService @@ -39,12 +39,10 @@ class ApiKeyAuthDataSourceBinding(Resource): @setup_required @login_required @account_initialization_required + @is_admin_or_owner_required def post(self): # The role of the current user in the table must be admin or owner - current_user, current_tenant_id = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() + _, current_tenant_id = current_account_with_tenant() parser = ( reqparse.RequestParser() .add_argument("category", type=str, required=True, nullable=False, location="json") @@ -65,12 +63,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @setup_required @login_required @account_initialization_required + @is_admin_or_owner_required def delete(self, binding_id): # The role of the current user in the table must be admin or owner - current_user, current_tenant_id = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() + _, current_tenant_id = current_account_with_tenant() ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id) diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0fd433d718..a27932ccd8 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -3,11 +3,11 @@ import logging import httpx from flask import current_app, redirect, request from flask_restx import Resource, fields -from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api, console_ns -from libs.login import current_account_with_tenant, login_required +from controllers.console.wraps import is_admin_or_owner_required +from libs.login import login_required from libs.oauth_data_source import NotionOAuth from ..wraps import account_initialization_required, setup_required @@ -42,11 +42,9 @@ class OAuthDataSource(Resource): ) @api.response(400, "Invalid provider") @api.response(403, "Admin privileges required") + @is_admin_or_owner_required def get(self, provider: str): # The role of the current user in the table must be admin or owner - current_user, _ = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 50bf48450c..3aac571300 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -15,6 +15,7 @@ from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_rate_limit_check, enterprise_license_required, + is_admin_or_owner_required, setup_required, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -753,13 +754,11 @@ class DatasetApiKeyApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required @marshal_with(api_key_fields) def post(self): - # The role of the current user in the ta table must be admin or owner - current_user, current_tenant_id = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() + _, current_tenant_id = current_account_with_tenant() current_key_count = ( db.session.query(ApiToken) @@ -794,15 +793,11 @@ class DatasetApiDeleteApi(Resource): @api.response(204, "API key deleted successfully") @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def delete(self, api_key_id): - current_user, current_tenant_id = current_account_with_tenant() + _, current_tenant_id = current_account_with_tenant() api_key_id = str(api_key_id) - - # The role of the current user in the ta table must be admin or owner - if not current_user.is_admin_or_owner: - raise Forbidden() - key = ( db.session.query(ApiToken) .where( diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 4f738db0e5..fe96a8199a 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -5,7 +5,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.console import api, console_ns from controllers.console.datasets.error import DatasetNameDuplicateError -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.dataset_fields import dataset_detail_fields from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService @@ -200,12 +200,10 @@ class ExternalDatasetCreateApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def post(self): # The role of the current user in the ta table must be admin, owner, or editor current_user, current_tenant_id = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() - parser = ( reqparse.RequestParser() .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index 2c28120e65..d658d65b71 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -1,11 +1,11 @@ from flask_restx import Resource, marshal_with, reqparse # type: ignore from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, + edit_permission_required, setup_required, ) from extensions.ext_database import db @@ -21,12 +21,11 @@ class RagPipelineImportApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @marshal_with(pipeline_import_fields) def post(self): # Check user role first current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() parser = ( reqparse.RequestParser() @@ -71,12 +70,10 @@ class RagPipelineImportConfirmApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @marshal_with(pipeline_import_fields) def post(self, import_id): current_user, _ = current_account_with_tenant() - # Check user role first - if not current_user.has_edit_permission: - raise Forbidden() # Create service with session with Session(db.engine) as session: @@ -98,12 +95,9 @@ class RagPipelineImportCheckDependenciesApi(Resource): @login_required @get_rag_pipeline @account_initialization_required + @edit_permission_required @marshal_with(pipeline_import_check_dependencies_fields) def get(self, pipeline: Pipeline): - current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() - with Session(db.engine) as session: import_service = RagPipelineDslService(session) result = import_service.check_dependencies(pipeline=pipeline) @@ -117,12 +111,9 @@ class RagPipelineExportApi(Resource): @login_required @get_rag_pipeline @account_initialization_required + @edit_permission_required def get(self, pipeline: Pipeline): - current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() - - # Add include_secret params + # Add include_secret params parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args") args = parser.parse_args() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 1e77a988bd..bc8d4fbf81 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -191,6 +191,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline def post(self, pipeline: Pipeline, node_id: str): """ @@ -198,8 +199,6 @@ class RagPipelineDraftRunLoopNodeApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() args = parser_run.parse_args() @@ -235,6 +234,7 @@ class DraftRagPipelineRunApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline def post(self, pipeline: Pipeline): """ @@ -242,8 +242,6 @@ class DraftRagPipelineRunApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() args = parser_draft_run.parse_args() @@ -279,6 +277,7 @@ class PublishedRagPipelineRunApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline def post(self, pipeline: Pipeline): """ @@ -286,8 +285,6 @@ class PublishedRagPipelineRunApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() args = parser_published_run.parse_args() @@ -404,6 +401,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline def post(self, pipeline: Pipeline, node_id: str): """ @@ -411,8 +409,6 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() args = parser_rag_run.parse_args() @@ -444,6 +440,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): @api.expect(parser_rag_run) @setup_required @login_required + @edit_permission_required @account_initialization_required @get_rag_pipeline def post(self, pipeline: Pipeline, node_id: str): @@ -452,8 +449,6 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() args = parser_rag_run.parse_args() @@ -490,6 +485,7 @@ class RagPipelineDraftNodeRunApi(Resource): @api.expect(parser_run_api) @setup_required @login_required + @edit_permission_required @account_initialization_required @get_rag_pipeline @marshal_with(workflow_run_node_execution_fields) @@ -499,8 +495,6 @@ class RagPipelineDraftNodeRunApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() args = parser_run_api.parse_args() @@ -523,6 +517,7 @@ class RagPipelineDraftNodeRunApi(Resource): class RagPipelineTaskStopApi(Resource): @setup_required @login_required + @edit_permission_required @account_initialization_required @get_rag_pipeline def post(self, pipeline: Pipeline, task_id: str): @@ -531,8 +526,6 @@ class RagPipelineTaskStopApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) @@ -544,6 +537,7 @@ class PublishedRagPipelineApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline @marshal_with(workflow_fields) def get(self, pipeline: Pipeline): @@ -551,9 +545,6 @@ class PublishedRagPipelineApi(Resource): Get published pipeline """ # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() if not pipeline.is_published: return None # fetch published workflow by pipeline @@ -566,6 +557,7 @@ class PublishedRagPipelineApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline def post(self, pipeline: Pipeline): """ @@ -573,9 +565,6 @@ class PublishedRagPipelineApi(Resource): """ # The role of the current user in the ta table must be admin, owner, or editor current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() - rag_pipeline_service = RagPipelineService() with Session(db.engine) as session: pipeline = session.merge(pipeline) @@ -602,16 +591,12 @@ class DefaultRagPipelineBlockConfigsApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline def get(self, pipeline: Pipeline): """ Get default block config """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() - # Get default block configs rag_pipeline_service = RagPipelineService() return rag_pipeline_service.get_default_block_configs() @@ -626,16 +611,12 @@ class DefaultRagPipelineBlockConfigApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline def get(self, pipeline: Pipeline, block_type: str): """ Get default block config """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() - args = parser_default.parse_args() q = args.get("q") @@ -667,6 +648,7 @@ class PublishedAllRagPipelineApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline @marshal_with(workflow_pagination_fields) def get(self, pipeline: Pipeline): @@ -674,8 +656,6 @@ class PublishedAllRagPipelineApi(Resource): Get published workflows """ current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() args = parser_wf.parse_args() page = args["page"] @@ -720,6 +700,7 @@ class RagPipelineByIdApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required @get_rag_pipeline @marshal_with(workflow_fields) def patch(self, pipeline: Pipeline, workflow_id: str): @@ -728,8 +709,6 @@ class RagPipelineByIdApi(Resource): """ # Check permission current_user, _ = current_account_with_tenant() - if not current_user.has_edit_permission: - raise Forbidden() args = parser_wf_id.parse_args() diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index ca8259238b..ee032756eb 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -3,7 +3,7 @@ from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api, console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from fields.tag_fields import dataset_tag_fields from libs.login import current_account_with_tenant, login_required from models.model import Tag @@ -91,12 +91,9 @@ class TagUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required + @edit_permission_required def delete(self, tag_id): - current_user, _ = current_account_with_tenant() tag_id = str(tag_id) - # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.has_edit_permission: - raise Forbidden() TagService.delete_tag(tag_id) diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index d115f62d73..ae870a630e 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,8 +1,7 @@ from flask_restx import Resource, fields, reqparse -from werkzeug.exceptions import Forbidden from controllers.console import api, console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError from libs.login import current_account_with_tenant, login_required @@ -31,11 +30,10 @@ class EndpointCreateApi(Resource): @api.response(403, "Admin privileges required") @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() - if not user.is_admin_or_owner: - raise Forbidden() parser = ( reqparse.RequestParser() @@ -168,6 +166,7 @@ class EndpointDeleteApi(Resource): @api.response(403, "Admin privileges required") @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() @@ -175,9 +174,6 @@ class EndpointDeleteApi(Resource): parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True) args = parser.parse_args() - if not user.is_admin_or_owner: - raise Forbidden() - endpoint_id = args["endpoint_id"] return { @@ -207,6 +203,7 @@ class EndpointUpdateApi(Resource): @api.response(403, "Admin privileges required") @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() @@ -223,9 +220,6 @@ class EndpointUpdateApi(Resource): settings = args["settings"] name = args["name"] - if not user.is_admin_or_owner: - raise Forbidden() - return { "success": EndpointService.update_endpoint( tenant_id=tenant_id, @@ -252,6 +246,7 @@ class EndpointEnableApi(Resource): @api.response(403, "Admin privileges required") @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() @@ -261,9 +256,6 @@ class EndpointEnableApi(Resource): endpoint_id = args["endpoint_id"] - if not user.is_admin_or_owner: - raise Forbidden() - return { "success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) } @@ -284,6 +276,7 @@ class EndpointDisableApi(Resource): @api.response(403, "Admin privileges required") @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() @@ -293,9 +286,6 @@ class EndpointDisableApi(Resource): endpoint_id = args["endpoint_id"] - if not user.is_admin_or_owner: - raise Forbidden() - return { "success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) } diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 832ec8af0f..05731b3832 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -2,10 +2,9 @@ import io from flask import send_file from flask_restx import Resource, reqparse -from werkzeug.exceptions import Forbidden from controllers.console import api, console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder @@ -85,12 +84,10 @@ class ModelProviderCredentialApi(Resource): @api.expect(parser_post_cred) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() - + _, current_tenant_id = current_account_with_tenant() args = parser_post_cred.parse_args() model_provider_service = ModelProviderService() @@ -110,11 +107,10 @@ class ModelProviderCredentialApi(Resource): @api.expect(parser_put_cred) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def put(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() + _, current_tenant_id = current_account_with_tenant() args = parser_put_cred.parse_args() @@ -136,12 +132,10 @@ class ModelProviderCredentialApi(Resource): @api.expect(parser_delete_cred) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def delete(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() - + _, current_tenant_id = current_account_with_tenant() args = parser_delete_cred.parse_args() model_provider_service = ModelProviderService() @@ -162,11 +156,10 @@ class ModelProviderCredentialSwitchApi(Resource): @api.expect(parser_switch) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() + _, current_tenant_id = current_account_with_tenant() args = parser_switch.parse_args() service = ModelProviderService() @@ -250,11 +243,10 @@ class PreferredProviderTypeUpdateApi(Resource): @api.expect(parser_preferred) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() + _, current_tenant_id = current_account_with_tenant() tenant_id = current_tenant_id diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d6aad129a6..79079f692e 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,10 +1,9 @@ import logging from flask_restx import Resource, reqparse -from werkzeug.exceptions import Forbidden from controllers.console import api, console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder @@ -50,12 +49,10 @@ class DefaultModelApi(Resource): @api.expect(parser_post_default) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): - current_user, tenant_id = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() + _, tenant_id = current_account_with_tenant() args = parser_post_default.parse_args() model_provider_service = ModelProviderService() @@ -133,13 +130,11 @@ class ModelProviderModelApi(Resource): @api.expect(parser_post_models) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): # To save the model's load balance configs - current_user, tenant_id = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() + _, tenant_id = current_account_with_tenant() args = parser_post_models.parse_args() if args.get("config_from", "") == "custom-model": @@ -181,12 +176,10 @@ class ModelProviderModelApi(Resource): @api.expect(parser_delete_models) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def delete(self, provider: str): - current_user, tenant_id = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() + _, tenant_id = current_account_with_tenant() args = parser_delete_models.parse_args() @@ -314,12 +307,10 @@ class ModelProviderModelCredentialApi(Resource): @api.expect(parser_post_cred) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - current_user, tenant_id = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() + _, tenant_id = current_account_with_tenant() args = parser_post_cred.parse_args() @@ -348,13 +339,10 @@ class ModelProviderModelCredentialApi(Resource): @api.expect(parser_put_cred) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def put(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() - + _, current_tenant_id = current_account_with_tenant() args = parser_put_cred.parse_args() model_provider_service = ModelProviderService() @@ -377,12 +365,10 @@ class ModelProviderModelCredentialApi(Resource): @api.expect(parser_delete_cred) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def delete(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() - - if not current_user.is_admin_or_owner: - raise Forbidden() + _, current_tenant_id = current_account_with_tenant() args = parser_delete_cred.parse_args() model_provider_service = ModelProviderService() @@ -417,12 +403,11 @@ class ModelProviderModelCredentialSwitchApi(Resource): @api.expect(parser_switch) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() + _, current_tenant_id = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() args = parser_switch.parse_args() service = ModelProviderService() diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index bb8c02b99a..deae418e96 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -7,7 +7,7 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api, console_ns from controllers.console.workspace import plugin_permission_required -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginDaemonClientSideError from libs.login import current_account_with_tenant, login_required @@ -132,9 +132,11 @@ class PluginAssetApi(Resource): @login_required @account_initialization_required def get(self): - req = reqparse.RequestParser() - req.add_argument("plugin_unique_identifier", type=str, required=True, location="args") - req.add_argument("file_name", type=str, required=True, location="args") + req = ( + reqparse.RequestParser() + .add_argument("plugin_unique_identifier", type=str, required=True, location="args") + .add_argument("file_name", type=str, required=True, location="args") + ) args = req.parse_args() _, tenant_id = current_account_with_tenant() @@ -619,13 +621,10 @@ class PluginFetchDynamicSelectOptionsApi(Resource): @api.expect(parser_dynamic) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def get(self): - # check if the user is admin or owner current_user, tenant_id = current_account_with_tenant() - if not current_user.is_admin_or_owner: - raise Forbidden() - user_id = current_user.id args = parser_dynamic.parse_args() @@ -770,9 +769,11 @@ class PluginReadmeApi(Resource): @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - parser = reqparse.RequestParser() - parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args") - parser.add_argument("language", type=str, required=False, location="args") + parser = ( + reqparse.RequestParser() + .add_argument("plugin_unique_identifier", type=str, required=True, location="args") + .add_argument("language", type=str, required=False, location="args") + ) args = parser.parse_args() return jsonable_encoder( { diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index cc7fa0fc3d..917059bb4c 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -14,6 +14,7 @@ from controllers.console import api, console_ns from controllers.console.wraps import ( account_initialization_required, enterprise_license_required, + is_admin_or_owner_required, setup_required, ) from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -115,11 +116,10 @@ class ToolBuiltinProviderDeleteApi(Resource): @api.expect(parser_delete) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider): - user, tenant_id = current_account_with_tenant() - if not user.is_admin_or_owner: - raise Forbidden() + _, tenant_id = current_account_with_tenant() args = parser_delete.parse_args() @@ -177,13 +177,10 @@ class ToolBuiltinProviderUpdateApi(Resource): @api.expect(parser_update) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider): user, tenant_id = current_account_with_tenant() - - if not user.is_admin_or_owner: - raise Forbidden() - user_id = user.id args = parser_update.parse_args() @@ -242,13 +239,11 @@ class ToolApiProviderAddApi(Resource): @api.expect(parser_api_add) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() - if not user.is_admin_or_owner: - raise Forbidden() - user_id = user.id args = parser_api_add.parse_args() @@ -336,13 +331,11 @@ class ToolApiProviderUpdateApi(Resource): @api.expect(parser_api_update) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() - if not user.is_admin_or_owner: - raise Forbidden() - user_id = user.id args = parser_api_update.parse_args() @@ -372,13 +365,11 @@ class ToolApiProviderDeleteApi(Resource): @api.expect(parser_api_delete) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() - if not user.is_admin_or_owner: - raise Forbidden() - user_id = user.id args = parser_api_delete.parse_args() @@ -496,13 +487,11 @@ class ToolWorkflowProviderCreateApi(Resource): @api.expect(parser_create) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() - if not user.is_admin_or_owner: - raise Forbidden() - user_id = user.id args = parser_create.parse_args() @@ -539,13 +528,10 @@ class ToolWorkflowProviderUpdateApi(Resource): @api.expect(parser_workflow_update) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() - - if not user.is_admin_or_owner: - raise Forbidden() - user_id = user.id args = parser_workflow_update.parse_args() @@ -577,13 +563,11 @@ class ToolWorkflowProviderDeleteApi(Resource): @api.expect(parser_workflow_delete) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self): user, tenant_id = current_account_with_tenant() - if not user.is_admin_or_owner: - raise Forbidden() - user_id = user.id args = parser_workflow_delete.parse_args() @@ -734,18 +718,15 @@ class ToolLabelsApi(Resource): class ToolPluginOAuthApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def get(self, provider): tool_provider = ToolProviderID(provider) plugin_id = tool_provider.plugin_id provider_name = tool_provider.provider_name - # todo check permission user, tenant_id = current_account_with_tenant() - if not user.is_admin_or_owner: - raise Forbidden() - oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) if oauth_client_params is None: raise Forbidden("no oauth available client config found for this tool provider") @@ -856,14 +837,12 @@ class ToolOAuthCustomClient(Resource): @api.expect(parser_custom) @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required - def post(self, provider): + def post(self, provider: str): args = parser_custom.parse_args() - user, tenant_id = current_account_with_tenant() - - if not user.is_admin_or_owner: - raise Forbidden() + _, tenant_id = current_account_with_tenant() return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=tenant_id, diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index bbbbe12fb0..b2abae0b3d 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -7,7 +7,7 @@ from werkzeug.exceptions import BadRequest, Forbidden from configs import dify_config from controllers.console import api -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from controllers.web.error import NotFoundError from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType @@ -67,14 +67,12 @@ class TriggerProviderInfoApi(Resource): class TriggerSubscriptionListApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def get(self, provider): """List all trigger subscriptions for the current tenant's provider""" user = current_user - assert isinstance(user, Account) assert user.current_tenant_id is not None - if not user.is_admin_or_owner: - raise Forbidden() try: return jsonable_encoder( @@ -92,17 +90,16 @@ class TriggerSubscriptionListApi(Resource): class TriggerSubscriptionBuilderCreateApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider): """Add a new subscription instance for a trigger provider""" user = current_user - assert isinstance(user, Account) assert user.current_tenant_id is not None - if not user.is_admin_or_owner: - raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json") + parser = reqparse.RequestParser().add_argument( + "credential_type", type=str, required=False, nullable=True, location="json" + ) args = parser.parse_args() try: @@ -133,18 +130,17 @@ class TriggerSubscriptionBuilderGetApi(Resource): class TriggerSubscriptionBuilderVerifyApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider, subscription_builder_id): """Verify a subscription instance for a trigger provider""" user = current_user - assert isinstance(user, Account) assert user.current_tenant_id is not None - if not user.is_admin_or_owner: - raise Forbidden() - - parser = reqparse.RequestParser() - # The credentials of the subscription builder - parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + # The credentials of the subscription builder + .add_argument("credentials", type=dict, required=False, nullable=True, location="json") + ) args = parser.parse_args() try: @@ -173,15 +169,17 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): assert isinstance(user, Account) assert user.current_tenant_id is not None - parser = reqparse.RequestParser() - # The name of the subscription builder - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - # The parameters of the subscription builder - parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json") - # The properties of the subscription builder - parser.add_argument("properties", type=dict, required=False, nullable=True, location="json") - # The credentials of the subscription builder - parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + # The name of the subscription builder + .add_argument("name", type=str, required=False, nullable=True, location="json") + # The parameters of the subscription builder + .add_argument("parameters", type=dict, required=False, nullable=True, location="json") + # The properties of the subscription builder + .add_argument("properties", type=dict, required=False, nullable=True, location="json") + # The credentials of the subscription builder + .add_argument("credentials", type=dict, required=False, nullable=True, location="json") + ) args = parser.parse_args() try: return jsonable_encoder( @@ -223,24 +221,23 @@ class TriggerSubscriptionBuilderLogsApi(Resource): class TriggerSubscriptionBuilderBuildApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider, subscription_builder_id): """Build a subscription instance for a trigger provider""" user = current_user - assert isinstance(user, Account) assert user.current_tenant_id is not None - if not user.is_admin_or_owner: - raise Forbidden() - - parser = reqparse.RequestParser() - # The name of the subscription builder - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - # The parameters of the subscription builder - parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json") - # The properties of the subscription builder - parser.add_argument("properties", type=dict, required=False, nullable=True, location="json") - # The credentials of the subscription builder - parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + # The name of the subscription builder + .add_argument("name", type=str, required=False, nullable=True, location="json") + # The parameters of the subscription builder + .add_argument("parameters", type=dict, required=False, nullable=True, location="json") + # The properties of the subscription builder + .add_argument("properties", type=dict, required=False, nullable=True, location="json") + # The credentials of the subscription builder + .add_argument("credentials", type=dict, required=False, nullable=True, location="json") + ) args = parser.parse_args() try: # Use atomic update_and_build to prevent race conditions @@ -264,14 +261,12 @@ class TriggerSubscriptionBuilderBuildApi(Resource): class TriggerSubscriptionDeleteApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, subscription_id: str): """Delete a subscription instance""" user = current_user - assert isinstance(user, Account) assert user.current_tenant_id is not None - if not user.is_admin_or_owner: - raise Forbidden() try: with Session(db.engine) as session: @@ -446,14 +441,12 @@ class TriggerOAuthCallbackApi(Resource): class TriggerOAuthClientManageApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def get(self, provider): """Get OAuth client configuration for a provider""" user = current_user - assert isinstance(user, Account) assert user.current_tenant_id is not None - if not user.is_admin_or_owner: - raise Forbidden() try: provider_id = TriggerProviderID(provider) @@ -493,18 +486,18 @@ class TriggerOAuthClientManageApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def post(self, provider): """Configure custom OAuth client for a provider""" user = current_user - assert isinstance(user, Account) assert user.current_tenant_id is not None - if not user.is_admin_or_owner: - raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") - parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json") + parser = ( + reqparse.RequestParser() + .add_argument("client_params", type=dict, required=False, nullable=True, location="json") + .add_argument("enabled", type=bool, required=False, nullable=True, location="json") + ) args = parser.parse_args() try: @@ -524,14 +517,12 @@ class TriggerOAuthClientManageApi(Resource): @setup_required @login_required + @is_admin_or_owner_required @account_initialization_required def delete(self, provider): """Remove custom OAuth client configuration""" user = current_user - assert isinstance(user, Account) assert user.current_tenant_id is not None - if not user.is_admin_or_owner: - raise Forbidden() try: provider_id = TriggerProviderID(provider) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 9b485544db..f40f566a36 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -315,3 +315,19 @@ def edit_permission_required(f: Callable[P, R]): return f(*args, **kwargs) return decorated_function + + +def is_admin_or_owner_required(f: Callable[P, R]): + @wraps(f) + def decorated_function(*args: P.args, **kwargs: P.kwargs): + from werkzeug.exceptions import Forbidden + + from libs.login import current_user + from models import Account + + user = current_user._get_current_object() + if not isinstance(user, Account) or not user.is_admin_or_owner: + raise Forbidden() + return f(*args, **kwargs) + + return decorated_function diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index ed013b1674..f26718555a 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -3,14 +3,12 @@ from typing import Literal from flask import request from flask_restx import Api, Namespace, Resource, fields, reqparse from flask_restx.api import HTTPStatus -from werkzeug.exceptions import Forbidden +from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import annotation_fields, build_annotation_model -from libs.login import current_user -from models import Account from models.model import App from services.annotation_service import AppAnnotationService @@ -161,14 +159,10 @@ class AnnotationUpdateDeleteApi(Resource): } ) @validate_app_token + @edit_permission_required @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) - def put(self, app_model: App, annotation_id): + def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" - assert isinstance(current_user, Account) - if not current_user.has_edit_permission: - raise Forbidden() - - annotation_id = str(annotation_id) args = annotation_create_parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) return annotation @@ -185,13 +179,8 @@ class AnnotationUpdateDeleteApi(Resource): } ) @validate_app_token - def delete(self, app_model: App, annotation_id): + @edit_permission_required + def delete(self, app_model: App, annotation_id: str): """Delete an annotation.""" - assert isinstance(current_user, Account) - - if not current_user.has_edit_permission: - raise Forbidden() - - annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) return {"result": "success"}, 204 diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 9d5566919b..4cca3e6ce8 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -5,6 +5,7 @@ from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, NotFound import services +from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.wraps import ( @@ -619,11 +620,9 @@ class DatasetTagsApi(DatasetApiResource): } ) @validate_dataset_token + @edit_permission_required def delete(self, _, dataset_id): """Delete a knowledge type tag.""" - assert isinstance(current_user, Account) - if not current_user.has_edit_permission: - raise Forbidden() args = tag_delete_parser.parse_args() TagService.delete_tag(args["tag_id"])