refactor: use libs.login current_user in console controllers (#26745)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
AsperforMias 2025-10-13 10:33:33 +08:00 committed by GitHub
parent 24cd7bbc62
commit 2f50f3fd4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 134 additions and 73 deletions

View File

@ -1,5 +1,4 @@
import flask_restx
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import select
@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Dataset
from models.model import ApiToken, App
@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource):
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = db.session.scalars(
select(ApiToken).where(
@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
current_key_count = (
@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner

View File

@ -1,9 +1,9 @@
from flask import request
from flask_login import current_user
from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.billing_service import BillingService
from .. import console_ns
@ -17,6 +17,8 @@ class ComplianceApi(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()

View File

@ -1,7 +1,5 @@
import logging
from typing import cast
from flask_login import current_user
from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -21,6 +19,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import current_user
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
assert isinstance(current_user, Account)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@ -57,11 +57,12 @@ class DatasetsHitTestingBase:
@staticmethod
def perform_hit_testing(dataset, args):
assert isinstance(current_user, Account)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=cast(Account, current_user),
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,

View File

@ -2,15 +2,15 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from libs.login import current_user, login_required
from models import InstalledApp
from models.account import Account
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
installed_app = (
db.session.query(InstalledApp)
.where(
@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
assert isinstance(current_user, Account)
app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(

View File

@ -1,11 +1,11 @@
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from constants import HIDDEN_VALUE
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
parser.add_argument("api_endpoint", type=str, required=True, location="json")
@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource):
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource):
@login_required
@account_initialization_required
def delete(self, id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id

View File

@ -1,7 +1,7 @@
from flask_login import current_user
from flask_restx import Resource, fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.feature_service import FeatureService
from . import api, console_ns
@ -23,6 +23,8 @@ class FeatureApi(Resource):
@cloud_utm_record
def get(self):
"""Get feature configuration for current tenant"""
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
return FeatureService.get_features(current_user.current_tenant_id).model_dump()

View File

@ -1,8 +1,6 @@
import urllib.parse
from typing import cast
import httpx
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
import services
@ -16,6 +14,7 @@ from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from libs.login import current_user
from models.account import Account
from services.file_service import FileService
@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user = cast(Account, current_user)
assert isinstance(current_user, Account)
user = current_user
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,

View File

@ -1,12 +1,12 @@
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from models.model import Tag
from services.tag_service import TagService
@ -24,6 +24,8 @@ class TagListApi(Resource):
@account_initialization_required
@marshal_with(dataset_tag_fields)
def get(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
@ -34,8 +36,10 @@ class TagListApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def patch(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, tag_id):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
if not current_user.has_edit_permission:
raise Forbidden()
TagService.delete_tag(tag_id)
@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()
@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.is_editor or current_user.is_dataset_editor):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
parser = reqparse.RequestParser()

View File

@ -1,10 +1,10 @@
from flask_login import current_user
from flask_restx import Resource, fields
from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.agent_service import AgentService
@ -21,7 +21,9 @@ class AgentProviderListApi(Resource):
@login_required
@account_initialization_required
def get(self):
assert isinstance(current_user, Account)
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
@ -43,7 +45,9 @@ class AgentProviderApi(Resource):
@login_required
@account_initialization_required
def get(self, provider_name: str):
assert isinstance(current_user, Account)
user = current_user
assert user.current_tenant_id is not None
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))

View File

@ -1,4 +1,3 @@
from flask_login import current_user
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden
@ -6,10 +5,18 @@ from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account
from services.plugin.endpoint_service import EndpointService
def _current_account_with_tenant() -> tuple[Account, str]:
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
return current_user, tenant_id
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@api.doc("create_endpoint")
@ -34,7 +41,7 @@ class EndpointCreateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
@ -51,7 +58,7 @@ class EndpointCreateApi(Resource):
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=plugin_unique_identifier,
name=name,
@ -80,7 +87,7 @@ class EndpointListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -93,7 +100,7 @@ class EndpointListApi(Resource):
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
page=page,
page_size=page_size,
@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource):
return jsonable_encoder(
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
page=page,
@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource):
endpoint_id = args["endpoint_id"]
return {
"success": EndpointService.delete_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource):
return {
"success": EndpointService.update_endpoint(
tenant_id=user.current_tenant_id,
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=name,
@ -250,7 +255,7 @@ class EndpointEnableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -262,9 +267,7 @@ class EndpointEnableApi(Resource):
raise Forbidden()
return {
"success": EndpointService.enable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}
@ -285,7 +288,7 @@ class EndpointDisableApi(Resource):
@login_required
@account_initialization_required
def post(self):
user = current_user
user, tenant_id = _current_account_with_tenant()
parser = reqparse.RequestParser()
parser.add_argument("endpoint_id", type=str, required=True)
@ -297,7 +300,5 @@ class EndpointDisableApi(Resource):
raise Forbidden()
return {
"success": EndpointService.disable_endpoint(
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
)
"success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
}

View File

@ -1,7 +1,6 @@
from urllib import parse
from flask import abort, request
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
import services
@ -26,7 +25,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError

View File

@ -1,7 +1,6 @@
import logging
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
@ -24,7 +23,7 @@ from controllers.console.wraps import (
)
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from libs.login import current_user, login_required
from models.account import Account, Tenant, TenantStatus
from services.account_service import TenantService
from services.feature_service import FeatureService

View File

@ -7,13 +7,13 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from flask_login import current_user
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import AccountStatus
from libs.login import current_user
from models.account import Account, AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
@ -25,11 +25,16 @@ P = ParamSpec("P")
R = TypeVar("R")
def _current_account() -> Account:
assert isinstance(current_user, Account)
return current_user
def account_initialization_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
# check account initialization
account = current_user
account = _current_account()
if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]):
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == "sandbox":
@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"
key = f"rate_limit_{tenant_id}"
redis_client.zadd(key, {current_time: current_time})
@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=current_user.current_tenant_id,
tenant_id=tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
tenant_id = account.current_tenant_id
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
OperationService.record_utm(tenant_id, utm_info_dict)
return view(*args, **kwargs)
@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]):
def knowledge_pipeline_publish_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
account = _current_account()
assert account.current_tenant_id is not None
features = FeatureService.get_features(account.current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)

View File

@ -60,7 +60,7 @@ class TestAccountInitialization:
return "success"
# Act
with patch("controllers.console.wraps.current_user", mock_user):
with patch("controllers.console.wraps._current_account", return_value=mock_user):
result = protected_view()
# Assert
@ -77,7 +77,7 @@ class TestAccountInitialization:
return "success"
# Act & Assert
with patch("controllers.console.wraps.current_user", mock_user):
with patch("controllers.console.wraps._current_account", return_value=mock_user):
with pytest.raises(AccountNotInitializedError):
protected_view()
@ -163,7 +163,7 @@ class TestBillingResourceLimits:
return "member_added"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = add_member()
@ -185,7 +185,7 @@ class TestBillingResourceLimits:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
add_member()
@ -207,7 +207,7 @@ class TestBillingResourceLimits:
# Test 1: Should reject when source is datasets
with app.test_request_context("/?source=datasets"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
upload_document()
@ -215,7 +215,7 @@ class TestBillingResourceLimits:
# Test 2: Should allow when source is not datasets
with app.test_request_context("/?source=other"):
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
result = upload_document()
assert result == "document_uploaded"
@ -239,7 +239,7 @@ class TestRateLimiting:
return "knowledge_success"
# Act
with patch("controllers.console.wraps.current_user"):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
@ -271,7 +271,7 @@ class TestRateLimiting:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")):
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):