mirror of https://github.com/langgenius/dify.git
refactor: use libs.login current_user in console controllers (#26745)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
parent
24cd7bbc62
commit
2f50f3fd4b
|
|
@ -1,5 +1,4 @@
|
|||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
|
|
@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden
|
|||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
|
@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource):
|
|||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
|
|
@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource):
|
|||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
|
|
@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource):
|
|||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
|
|
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
|
|||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
|
|
@ -21,6 +19,7 @@ from core.errors.error import (
|
|||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
|
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
|
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
|
|||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=cast(Account, current_user),
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
|
|
|||
|
|
@ -2,15 +2,15 @@ from collections.abc import Callable
|
|||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.explore.error import AppAccessDeniedError
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import InstalledApp
|
||||
from models.account import Account
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
|
|
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
assert isinstance(current_user, Account)
|
||||
app_id = installed_app.app_id
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
|
@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tenant_id = current_user.current_tenant_id
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
|
|
@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
||||
|
|
@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api, console_ns
|
||||
|
|
@ -23,6 +23,8 @@ class FeatureApi(Resource):
|
|||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -16,6 +14,7 @@ from core.file import helpers as file_helpers
|
|||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
|
@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource):
|
|||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.tag_fields import dataset_tag_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
|
@ -24,6 +24,8 @@ class TagListApi(Resource):
|
|||
@account_initialization_required
|
||||
@marshal_with(dataset_tag_fields)
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_type = request.args.get("type", type=str, default="")
|
||||
keyword = request.args.get("keyword", default=None, type=str)
|
||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||
|
|
@ -34,8 +36,10 @@ class TagListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, tag_id):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
|
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
|
|
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
|
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
assert isinstance(current_user, Account)
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
|
@ -6,10 +5,18 @@ from controllers.console import api, console_ns
|
|||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
def _current_account_with_tenant() -> tuple[Account, str]:
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
return current_user, tenant_id
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@api.doc("create_endpoint")
|
||||
|
|
@ -34,7 +41,7 @@ class EndpointCreateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
|
|
@ -51,7 +58,7 @@ class EndpointCreateApi(Resource):
|
|||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
name=name,
|
||||
|
|
@ -80,7 +87,7 @@ class EndpointListApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -93,7 +100,7 @@ class EndpointListApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
|
|
@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
|
|
@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource):
|
|||
return jsonable_encoder(
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
|
|
@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource):
|
|||
endpoint_id = args["endpoint_id"]
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource):
|
|||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=name,
|
||||
|
|
@ -250,7 +255,7 @@ class EndpointEnableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -262,9 +267,7 @@ class EndpointEnableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -285,7 +288,7 @@ class EndpointDisableApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user = current_user
|
||||
user, tenant_id = _current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("endpoint_id", type=str, required=True)
|
||||
|
|
@ -297,7 +300,5 @@ class EndpointDisableApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||
)
|
||||
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from urllib import parse
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
|
|
@ -26,7 +25,7 @@ from controllers.console.wraps import (
|
|||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
|
@ -24,7 +23,7 @@ from controllers.console.wraps import (
|
|||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from functools import wraps
|
|||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import AccountStatus
|
||||
from libs.login import current_user
|
||||
from models.account import Account, AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
|
@ -25,11 +25,16 @@ P = ParamSpec("P")
|
|||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _current_account() -> Account:
|
||||
assert isinstance(current_user, Account)
|
||||
return current_user
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
# check account initialization
|
||||
account = current_user
|
||||
account = _current_account()
|
||||
|
||||
if account.status == AccountStatus.UNINITIALIZED:
|
||||
raise AccountNotInitializedError()
|
||||
|
|
@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
|||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
|
@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
|
|
@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
|||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
|
|
@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if resource == "knowledge":
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{current_user.current_tenant_id}"
|
||||
key = f"rate_limit_{tenant_id}"
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
|
||||
|
|
@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
if request_count > knowledge_rate_limit.limit:
|
||||
# add ratelimit record
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
|
|
@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]):
|
|||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
tenant_id = account.current_tenant_id
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
|
||||
OperationService.record_utm(tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]):
|
|||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
|||
def knowledge_pipeline_publish_enabled(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
account = _current_account()
|
||||
assert account.current_tenant_id is not None
|
||||
features = FeatureService.get_features(account.current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class TestAccountInitialization:
|
|||
return "success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user", mock_user):
|
||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
|
|
@ -77,7 +77,7 @@ class TestAccountInitialization:
|
|||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.wraps.current_user", mock_user):
|
||||
with patch("controllers.console.wraps._current_account", return_value=mock_user):
|
||||
with pytest.raises(AccountNotInitializedError):
|
||||
protected_view()
|
||||
|
||||
|
|
@ -163,7 +163,7 @@ class TestBillingResourceLimits:
|
|||
return "member_added"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user"):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = add_member()
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
add_member()
|
||||
|
|
@ -207,7 +207,7 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Test 1: Should reject when source is datasets
|
||||
with app.test_request_context("/?source=datasets"):
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
upload_document()
|
||||
|
|
@ -215,7 +215,7 @@ class TestBillingResourceLimits:
|
|||
|
||||
# Test 2: Should allow when source is not datasets
|
||||
with app.test_request_context("/?source=other"):
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = upload_document()
|
||||
assert result == "document_uploaded"
|
||||
|
|
@ -239,7 +239,7 @@ class TestRateLimiting:
|
|||
return "knowledge_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user"):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
|
|
@ -271,7 +271,7 @@ class TestRateLimiting:
|
|||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
|
|
|
|||
Loading…
Reference in New Issue