diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index fec527e4cb..b1e3813f33 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,5 +1,4 @@ import flask_restx -from flask_login import current_user from flask_restx import Resource, fields, marshal_with from flask_restx._http import HTTPStatus from sqlalchemy import select @@ -8,7 +7,8 @@ from werkzeug.exceptions import Forbidden from extensions.ext_database import db from libs.helper import TimestampField -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from models.dataset import Dataset from models.model import ApiToken, App @@ -57,6 +57,8 @@ class BaseApiKeyListResource(Resource): def get(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) keys = db.session.scalars( select(ApiToken).where( @@ -69,8 +71,10 @@ class BaseApiKeyListResource(Resource): def post(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() current_key_count = ( @@ -108,6 +112,8 @@ class BaseApiKeyResource(Resource): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) api_key_id = str(api_key_id) + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) # The role of the current user in the ta table must be admin or owner diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py index e489b48c82..c0d104e0d4 100644 --- a/api/controllers/console/billing/compliance.py +++ b/api/controllers/console/billing/compliance.py @@ -1,9 +1,9 @@ from flask import request -from flask_login import current_user from flask_restx import Resource, reqparse from libs.helper import extract_remote_ip -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from services.billing_service import BillingService from .. import console_ns @@ -17,6 +17,8 @@ class ComplianceApi(Resource): @account_initialization_required @only_edition_cloud def get(self): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None parser = reqparse.RequestParser() parser.add_argument("doc_name", type=str, required=True, location="args") args = parser.parse_args() diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index a68e337135..6113f1fd17 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,5 @@ import logging -from typing import cast -from flask_login import current_user from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -21,6 +19,7 @@ from core.errors.error import ( ) from core.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields +from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService @@ -31,6 +30,7 @@ logger = logging.getLogger(__name__) class DatasetsHitTestingBase: @staticmethod def get_and_validate_dataset(dataset_id: str): + assert isinstance(current_user, Account) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: raise NotFound("Dataset not found.") @@ -57,11 +57,12 @@ class DatasetsHitTestingBase: @staticmethod def perform_hit_testing(dataset, args): + assert isinstance(current_user, Account) try: response = HitTestingService.retrieve( dataset=dataset, query=args["query"], - account=cast(Account, current_user), + account=current_user, retrieval_model=args["retrieval_model"], external_retrieval_model=args["external_retrieval_model"], limit=10, diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 3a8ba64a03..5956eb52c4 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -2,15 +2,15 @@ from collections.abc import Callable from functools import wraps from typing import Concatenate, ParamSpec, TypeVar -from flask_login import current_user from flask_restx import Resource from werkzeug.exceptions import NotFound from controllers.console.explore.error import AppAccessDeniedError from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from libs.login import login_required +from libs.login import current_user, login_required from models import InstalledApp +from models.account import Account from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -24,6 +24,8 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non def decorator(view: Callable[Concatenate[InstalledApp, P], R]): @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None installed_app = ( db.session.query(InstalledApp) .where( @@ -56,6 +58,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs): feature = FeatureService.get_system_features() if feature.webapp_auth.enabled: + assert isinstance(current_user, Account) app_id = installed_app.app_id app_code = AppService.get_app_code_by_id(app_id) res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 57f5ab191e..c6b3cf7515 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,11 +1,11 @@ -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from constants import HIDDEN_VALUE from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService @@ -47,6 +47,8 @@ class APIBasedExtensionAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def get(self): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None tenant_id = current_user.current_tenant_id return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) @@ -68,6 +70,8 @@ class APIBasedExtensionAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def post(self): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("api_endpoint", type=str, required=True, location="json") @@ -95,6 +99,8 @@ class APIBasedExtensionDetailAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def get(self, id): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None api_based_extension_id = str(id) tenant_id = current_user.current_tenant_id @@ -119,6 +125,8 @@ class APIBasedExtensionDetailAPI(Resource): @account_initialization_required @marshal_with(api_based_extension_fields) def post(self, id): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None api_based_extension_id = str(id) tenant_id = current_user.current_tenant_id @@ -146,6 +154,8 @@ class APIBasedExtensionDetailAPI(Resource): @login_required @account_initialization_required def delete(self, id): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None api_based_extension_id = str(id) tenant_id = current_user.current_tenant_id diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index d43b839291..80847b8fef 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,7 +1,7 @@ -from flask_login import current_user from flask_restx import Resource, fields -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from services.feature_service import FeatureService from . import api, console_ns @@ -23,6 +23,8 @@ class FeatureApi(Resource): @cloud_utm_record def get(self): """Get feature configuration for current tenant""" + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None return FeatureService.get_features(current_user.current_tenant_id).model_dump() diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 7aaf807fb0..4d4bb5d779 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,8 +1,6 @@ import urllib.parse -from typing import cast import httpx -from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse import services @@ -16,6 +14,7 @@ from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from libs.login import current_user from models.account import Account from services.file_service import FileService @@ -65,7 +64,8 @@ class RemoteFileUploadApi(Resource): content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content try: - user = cast(Account, current_user) + assert isinstance(current_user, Account) + user = current_user upload_file = FileService(db.engine).upload_file( filename=file_info.filename, content=content, diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 3d29b3ee61..b6086c5766 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,12 +1,12 @@ from flask import request -from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.tag_fields import dataset_tag_fields -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from models.model import Tag from services.tag_service import TagService @@ -24,6 +24,8 @@ class TagListApi(Resource): @account_initialization_required @marshal_with(dataset_tag_fields) def get(self): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) @@ -34,8 +36,10 @@ class TagListApi(Resource): @login_required @account_initialization_required def post(self): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None # The role of the current user in the ta table must be admin, owner, or editor - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() parser = reqparse.RequestParser() @@ -59,9 +63,11 @@ class TagUpdateDeleteApi(Resource): @login_required @account_initialization_required def patch(self, tag_id): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None tag_id = str(tag_id) # The role of the current user in the ta table must be admin, owner, or editor - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() parser = reqparse.RequestParser() @@ -81,9 +87,11 @@ class TagUpdateDeleteApi(Resource): @login_required @account_initialization_required def delete(self, tag_id): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None tag_id = str(tag_id) # The role of the current user in the ta table must be admin, owner, or editor - if not current_user.is_editor: + if not current_user.has_edit_permission: raise Forbidden() TagService.delete_tag(tag_id) @@ -97,8 +105,10 @@ class TagBindingCreateApi(Resource): @login_required @account_initialization_required def post(self): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() parser = reqparse.RequestParser() @@ -123,8 +133,10 @@ class TagBindingDeleteApi(Resource): @login_required @account_initialization_required def post(self): + assert isinstance(current_user, Account) + assert current_user.current_tenant_id is not None # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.is_editor or current_user.is_dataset_editor): + if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() parser = reqparse.RequestParser() diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 0a2c8fcfb4..e044b2db5b 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,10 +1,10 @@ -from flask_login import current_user from flask_restx import Resource, fields from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from services.agent_service import AgentService @@ -21,7 +21,9 @@ class AgentProviderListApi(Resource): @login_required @account_initialization_required def get(self): + assert isinstance(current_user, Account) user = current_user + assert user.current_tenant_id is not None user_id = user.id tenant_id = user.current_tenant_id @@ -43,7 +45,9 @@ class AgentProviderApi(Resource): @login_required @account_initialization_required def get(self, provider_name: str): + assert isinstance(current_user, Account) user = current_user + assert user.current_tenant_id is not None user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 0657b764cc..782bd72565 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import Forbidden @@ -6,10 +5,18 @@ from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from services.plugin.endpoint_service import EndpointService +def _current_account_with_tenant() -> tuple[Account, str]: + assert isinstance(current_user, Account) + tenant_id = current_user.current_tenant_id + assert tenant_id is not None + return current_user, tenant_id + + @console_ns.route("/workspaces/current/endpoints/create") class EndpointCreateApi(Resource): @api.doc("create_endpoint") @@ -34,7 +41,7 @@ class EndpointCreateApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = _current_account_with_tenant() if not user.is_admin_or_owner: raise Forbidden() @@ -51,7 +58,7 @@ class EndpointCreateApi(Resource): try: return { "success": EndpointService.create_endpoint( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, user_id=user.id, plugin_unique_identifier=plugin_unique_identifier, name=name, @@ -80,7 +87,7 @@ class EndpointListApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = _current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -93,7 +100,7 @@ class EndpointListApi(Resource): return jsonable_encoder( { "endpoints": EndpointService.list_endpoints( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, user_id=user.id, page=page, page_size=page_size, @@ -123,7 +130,7 @@ class EndpointListForSinglePluginApi(Resource): @login_required @account_initialization_required def get(self): - user = current_user + user, tenant_id = _current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("page", type=int, required=True, location="args") @@ -138,7 +145,7 @@ class EndpointListForSinglePluginApi(Resource): return jsonable_encoder( { "endpoints": EndpointService.list_endpoints_for_single_plugin( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, user_id=user.id, plugin_id=plugin_id, page=page, @@ -165,7 +172,7 @@ class EndpointDeleteApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = _current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -177,9 +184,7 @@ class EndpointDeleteApi(Resource): endpoint_id = args["endpoint_id"] return { - "success": EndpointService.delete_endpoint( - tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id - ) + "success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) } @@ -207,7 +212,7 @@ class EndpointUpdateApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = _current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -224,7 +229,7 @@ class EndpointUpdateApi(Resource): return { "success": EndpointService.update_endpoint( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id, name=name, @@ -250,7 +255,7 @@ class EndpointEnableApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = _current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -262,9 +267,7 @@ class EndpointEnableApi(Resource): raise Forbidden() return { - "success": EndpointService.enable_endpoint( - tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id - ) + "success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) } @@ -285,7 +288,7 @@ class EndpointDisableApi(Resource): @login_required @account_initialization_required def post(self): - user = current_user + user, tenant_id = _current_account_with_tenant() parser = reqparse.RequestParser() parser.add_argument("endpoint_id", type=str, required=True) @@ -297,7 +300,5 @@ class EndpointDisableApi(Resource): raise Forbidden() return { - "success": EndpointService.disable_endpoint( - tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id - ) + "success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id) } diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 8b89853bd9..dd6a878d87 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,7 +1,6 @@ from urllib import parse from flask import abort, request -from flask_login import current_user from flask_restx import Resource, marshal_with, reqparse import services @@ -26,7 +25,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields from libs.helper import extract_remote_ip -from libs.login import login_required +from libs.login import current_user, login_required from models.account import Account, TenantAccountRole from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index bc748ac3d2..4a0539785a 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,7 +1,6 @@ import logging from flask import request -from flask_login import current_user from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from werkzeug.exceptions import Unauthorized @@ -24,7 +23,7 @@ from controllers.console.wraps import ( ) from extensions.ext_database import db from libs.helper import TimestampField -from libs.login import login_required +from libs.login import current_user, login_required from models.account import Account, Tenant, TenantStatus from services.account_service import TenantService from services.feature_service import FeatureService diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 914d386c78..9e903d9286 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -7,13 +7,13 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import abort, request -from flask_login import current_user from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.account import AccountStatus +from libs.login import current_user +from models.account import Account, AccountStatus from models.dataset import RateLimitLog from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus @@ -25,11 +25,16 @@ P = ParamSpec("P") R = TypeVar("R") +def _current_account() -> Account: + assert isinstance(current_user, Account) + return current_user + + def account_initialization_required(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): # check account initialization - account = current_user + account = _current_account() if account.status == AccountStatus.UNINITIALIZED: raise AccountNotInitializedError() @@ -75,7 +80,9 @@ def only_edition_self_hosted(view: Callable[P, R]): def cloud_edition_billing_enabled(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - features = FeatureService.get_features(current_user.current_tenant_id) + account = _current_account() + assert account.current_tenant_id is not None + features = FeatureService.get_features(account.current_tenant_id) if not features.billing.enabled: abort(403, "Billing feature is not enabled.") return view(*args, **kwargs) @@ -87,7 +94,10 @@ def cloud_edition_billing_resource_check(resource: str): def interceptor(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - features = FeatureService.get_features(current_user.current_tenant_id) + account = _current_account() + assert account.current_tenant_id is not None + tenant_id = account.current_tenant_id + features = FeatureService.get_features(tenant_id) if features.billing.enabled: members = features.members apps = features.apps @@ -128,7 +138,9 @@ def cloud_edition_billing_knowledge_limit_check(resource: str): def interceptor(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - features = FeatureService.get_features(current_user.current_tenant_id) + account = _current_account() + assert account.current_tenant_id is not None + features = FeatureService.get_features(account.current_tenant_id) if features.billing.enabled: if resource == "add_segment": if features.billing.subscription.plan == "sandbox": @@ -151,10 +163,13 @@ def cloud_edition_billing_rate_limit_check(resource: str): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): if resource == "knowledge": - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id) + account = _current_account() + assert account.current_tenant_id is not None + tenant_id = account.current_tenant_id + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id) if knowledge_rate_limit.enabled: current_time = int(time.time() * 1000) - key = f"rate_limit_{current_user.current_tenant_id}" + key = f"rate_limit_{tenant_id}" redis_client.zadd(key, {current_time: current_time}) @@ -165,7 +180,7 @@ def cloud_edition_billing_rate_limit_check(resource: str): if request_count > knowledge_rate_limit.limit: # add ratelimit record rate_limit_log = RateLimitLog( - tenant_id=current_user.current_tenant_id, + tenant_id=tenant_id, subscription_plan=knowledge_rate_limit.subscription_plan, operation="knowledge", ) @@ -185,14 +200,17 @@ def cloud_utm_record(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): with contextlib.suppress(Exception): - features = FeatureService.get_features(current_user.current_tenant_id) + account = _current_account() + assert account.current_tenant_id is not None + tenant_id = account.current_tenant_id + features = FeatureService.get_features(tenant_id) if features.billing.enabled: utm_info = request.cookies.get("utm_info") if utm_info: utm_info_dict: dict = json.loads(utm_info) - OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) + OperationService.record_utm(tenant_id, utm_info_dict) return view(*args, **kwargs) @@ -271,7 +289,9 @@ def enable_change_email(view: Callable[P, R]): def is_allow_transfer_owner(view: Callable[P, R]): @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs): - features = FeatureService.get_features(current_user.current_tenant_id) + account = _current_account() + assert account.current_tenant_id is not None + features = FeatureService.get_features(account.current_tenant_id) if features.is_allow_transfer_workspace: return view(*args, **kwargs) @@ -284,7 +304,9 @@ def is_allow_transfer_owner(view: Callable[P, R]): def knowledge_pipeline_publish_enabled(view): @wraps(view) def decorated(*args, **kwargs): - features = FeatureService.get_features(current_user.current_tenant_id) + account = _current_account() + assert account.current_tenant_id is not None + features = FeatureService.get_features(account.current_tenant_id) if features.knowledge_pipeline.publish_enabled: return view(*args, **kwargs) abort(403) diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 9742368f04..5d132cb787 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -60,7 +60,7 @@ class TestAccountInitialization: return "success" # Act - with patch("controllers.console.wraps.current_user", mock_user): + with patch("controllers.console.wraps._current_account", return_value=mock_user): result = protected_view() # Assert @@ -77,7 +77,7 @@ class TestAccountInitialization: return "success" # Act & Assert - with patch("controllers.console.wraps.current_user", mock_user): + with patch("controllers.console.wraps._current_account", return_value=mock_user): with pytest.raises(AccountNotInitializedError): protected_view() @@ -163,7 +163,7 @@ class TestBillingResourceLimits: return "member_added" # Act - with patch("controllers.console.wraps.current_user"): + with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): result = add_member() @@ -185,7 +185,7 @@ class TestBillingResourceLimits: # Act & Assert with app.test_request_context(): - with patch("controllers.console.wraps.current_user", MockUser("test_user")): + with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with pytest.raises(Exception) as exc_info: add_member() @@ -207,7 +207,7 @@ class TestBillingResourceLimits: # Test 1: Should reject when source is datasets with app.test_request_context("/?source=datasets"): - with patch("controllers.console.wraps.current_user", MockUser("test_user")): + with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): with pytest.raises(Exception) as exc_info: upload_document() @@ -215,7 +215,7 @@ class TestBillingResourceLimits: # Test 2: Should allow when source is not datasets with app.test_request_context("/?source=other"): - with patch("controllers.console.wraps.current_user", MockUser("test_user")): + with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): result = upload_document() assert result == "document_uploaded" @@ -239,7 +239,7 @@ class TestRateLimiting: return "knowledge_success" # Act - with patch("controllers.console.wraps.current_user"): + with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ): @@ -271,7 +271,7 @@ class TestRateLimiting: # Act & Assert with app.test_request_context(): - with patch("controllers.console.wraps.current_user", MockUser("test_user")): + with patch("controllers.console.wraps._current_account", return_value=MockUser("test_user")): with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ):