diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index e8e8234ac4..cbdcdc8f10 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -9,9 +9,11 @@ from controllers.console.wraps import ( cloud_edition_billing_resource_check, edit_permission_required, setup_required, + with_current_user, ) from extensions.ext_database import db -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models.account import Account from models.model import App from services.app_dsl_service import AppDslService, Import from services.enterprise.enterprise_service import EnterpriseService @@ -48,9 +50,9 @@ class AppImportApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") @edit_permission_required - def post(self): + @with_current_user + def post(self, current_user: Account): # Check user role first - current_user, _ = current_account_with_tenant() args = AppImportPayload.model_validate(console_ns.payload) # AppDslService performs internal commits for some creation paths, so use a plain @@ -97,10 +99,9 @@ class AppImportConfirmApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, import_id: str): + @with_current_user + def post(self, current_user: Account, import_id: str): # Check user role first - current_user, _ = current_account_with_tenant() - with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Confirm import diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index a5b1a8c77d..db2ca624f1 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -12,7 +12,12 @@ from werkzeug.exceptions import NotFound from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_user, +) from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( @@ -31,8 +36,9 @@ from fields.conversation_fields import ( ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse, ) from libs.datetime_utils import naive_utc_now, parse_time_range -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models import Conversation, EndUser, Message, MessageAnnotation +from models.account import Account from models.model import App, AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError @@ -93,8 +99,8 @@ class CompletionConversationApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @edit_permission_required - def get(self, app_model: App): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, app_model: App): args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) query = sa.select(Conversation).where( @@ -165,10 +171,11 @@ class CompletionConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @edit_permission_required - def get(self, app_model: App, conversation_id: UUID): + @with_current_user + def get(self, current_user: Account, app_model: App, conversation_id: UUID): conversation_id_str = str(conversation_id) return ConversationMessageDetailResponse.model_validate( - _get_conversation(app_model, conversation_id_str), from_attributes=True + _get_conversation(current_user, app_model, conversation_id_str), from_attributes=True ).model_dump(mode="json") @console_ns.doc("delete_completion_conversation") @@ -182,8 +189,8 @@ class CompletionConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) @edit_permission_required - def delete(self, app_model: App, conversation_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def delete(self, current_user: Account, app_model: App, conversation_id: UUID): conversation_id_str = str(conversation_id) try: @@ -207,8 +214,8 @@ class ChatConversationApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @edit_permission_required - def get(self, app_model: App): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, app_model: App): args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) subquery = ( @@ -318,10 +325,11 @@ class ChatConversationDetailApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @edit_permission_required - def get(self, app_model: App, conversation_id: UUID): + @with_current_user + def get(self, current_user: Account, app_model: App, conversation_id: UUID): conversation_id_str = str(conversation_id) return ConversationDetailResponse.model_validate( - _get_conversation(app_model, conversation_id_str), from_attributes=True + _get_conversation(current_user, app_model, conversation_id_str), from_attributes=True ).model_dump(mode="json") @console_ns.doc("delete_chat_conversation") @@ -335,8 +343,8 @@ class ChatConversationDetailApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @account_initialization_required @edit_permission_required - def delete(self, app_model: App, conversation_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def delete(self, current_user: Account, app_model: App, conversation_id: UUID): conversation_id_str = str(conversation_id) try: @@ -347,8 +355,7 @@ class ChatConversationDetailApi(Resource): return "", 204 -def _get_conversation(app_model, conversation_id): - current_user, _ = current_account_with_tenant() +def _get_conversation(current_user: Account, app_model, conversation_id): conversation = db.session.scalar( sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1) ) diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 15b3437bf9..833de38f34 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -25,6 +25,7 @@ from controllers.console.wraps import ( account_initialization_required, edit_permission_required, setup_required, + with_current_user, ) from core.app.entities.app_invoke_entities import InvokeFrom from core.entities.execution_extra_content import ExecutionExtraContentDomainModel @@ -43,7 +44,8 @@ from fields.conversation_fields import ( from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import to_timestamp, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models.account import Account from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.errors.conversation import ConversationNotExistsError @@ -257,9 +259,8 @@ class MessageFeedbackApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_model: App): - current_user, _ = current_account_with_tenant() - + @with_current_user + def post(self, current_user: Account, app_model: App): args = MessageFeedbackPayload.model_validate(console_ns.payload) message_id = str(args.message_id) @@ -337,8 +338,8 @@ class MessageSuggestedQuestionApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - def get(self, app_model: App, message_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, app_model: App, message_id: UUID): message_id_str = str(message_id) try: diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index ca7f194a35..df398fa7b9 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -14,12 +14,14 @@ from controllers.console.wraps import ( edit_permission_required, is_admin_or_owner_required, setup_required, + with_current_user, ) from extensions.ext_database import db from fields.base import ResponseModel from libs.datetime_utils import naive_utc_now -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models import Site +from models.account import Account from models.model import App @@ -85,9 +87,9 @@ class AppSite(Resource): @edit_permission_required @account_initialization_required @get_app_model - def post(self, app_model: App): + @with_current_user + def post(self, current_user: Account, app_model: App): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) - current_user, _ = current_account_with_tenant() site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise NotFound @@ -134,8 +136,8 @@ class AppSiteAccessTokenReset(Resource): @is_admin_or_owner_required @account_initialization_required @get_app_model - def post(self, app_model: App): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, app_model: App): site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index dd3cf273a2..9e3394e6c1 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -8,13 +8,14 @@ from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_user from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.datetime_utils import parse_time_range from libs.helper import convert_datetime_to_date -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models import AppMode +from models.account import Account from models.model import App @@ -48,9 +49,8 @@ class DailyMessageStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") @@ -109,9 +109,8 @@ class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") @@ -169,9 +168,8 @@ class DailyTerminalsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") @@ -230,9 +228,8 @@ class DailyTokenCostStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") @@ -294,9 +291,8 @@ class AverageSessionInteractionStatistic(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("c.created_at") @@ -374,9 +370,8 @@ class UserSatisfactionRateStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("m.created_at") @@ -444,9 +439,8 @@ class AverageResponseTimeStatistic(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") @@ -505,8 +499,8 @@ class TokensPerSecondStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() + @with_current_user + def get(self, account: Account, app_model: App): args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 7b5a628561..05d579527e 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -6,10 +6,11 @@ from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_user from extensions.ext_database import db from libs.datetime_utils import parse_time_range -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models.account import Account from models.enums import WorkflowRunTriggeredFrom from models.model import App, AppMode from repositories.factory import DifyAPIRepositoryFactory @@ -46,9 +47,8 @@ class WorkflowDailyRunsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -86,9 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -126,9 +125,8 @@ class WorkflowDailyTokenCostStatistic(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -166,9 +164,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) - def get(self, app_model: App): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account, app_model: App): args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 712fce2aa9..6a1b4c6769 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -32,11 +32,11 @@ from controllers.console.wraps import ( decrypt_password_field, email_password_login_enabled, setup_required, + with_current_user, ) from events.tenant_event import tenant_was_created from libs.helper import EmailStr, extract_remote_ip from libs.helper import timezone as validate_timezone_string -from libs.login import current_account_with_tenant from libs.token import ( clear_access_token_from_cookie, clear_csrf_token_from_cookie, @@ -46,6 +46,7 @@ from libs.token import ( set_csrf_token_to_cookie, set_refresh_token_to_cookie, ) +from models.account import Account from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService from services.billing_service import BillingService from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase @@ -172,9 +173,8 @@ class LoginApi(Resource): class LogoutApi(Resource): @setup_required @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() - account = current_user + @with_current_user + def post(self, account: Account): if isinstance(account, flask_login.AnonymousUserMixin): response = make_response({"result": "success"}) else: diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index cf516aa63b..5b53c40ae9 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -7,14 +7,20 @@ from werkzeug.exceptions import NotFound from controllers.common.controller_schemas import MetadataUpdatePayload from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, + with_current_user, +) from fields.dataset_fields import ( DatasetMetadataBuiltInFieldsResponse, DatasetMetadataListResponse, DatasetMetadataResponse, ) from libs.helper import dump_response -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models.account import Account from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( DocumentMetadataOperation, @@ -43,8 +49,8 @@ class DatasetMetadataCreateApi(Resource): @enterprise_license_required @console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__]) @console_ns.expect(console_ns.models[MetadataArgs.__name__]) - def post(self, dataset_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, dataset_id: UUID): metadata_args = MetadataArgs.model_validate(console_ns.payload or {}) dataset_id_str = str(dataset_id) @@ -80,8 +86,8 @@ class DatasetMetadataApi(Resource): @enterprise_license_required @console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__]) @console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__]) - def patch(self, dataset_id: UUID, metadata_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID): payload = MetadataUpdatePayload.model_validate(console_ns.payload or {}) name = payload.name @@ -100,8 +106,8 @@ class DatasetMetadataApi(Resource): @account_initialization_required @enterprise_license_required @console_ns.response(204, "Metadata deleted successfully") - def delete(self, dataset_id: UUID, metadata_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID): dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -137,8 +143,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): @account_initialization_required @enterprise_license_required @console_ns.response(204, "Action completed successfully") - def post(self, dataset_id: UUID, action: Literal["enable", "disable"]): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -165,8 +171,8 @@ class DocumentMetadataEditApi(Resource): 204, "Documents metadata updated successfully", ) - def post(self, dataset_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index 3ae5d308c2..1937bd9781 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -10,6 +10,7 @@ from controllers.console.wraps import ( account_initialization_required, edit_permission_required, setup_required, + with_current_user, ) from extensions.ext_database import db from fields.rag_pipeline_fields import ( @@ -17,7 +18,8 @@ from fields.rag_pipeline_fields import ( pipeline_import_check_dependencies_fields, pipeline_import_fields, ) -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models.account import Account from models.dataset import Pipeline from services.entities.dsl_entities import ImportStatus from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService @@ -62,9 +64,9 @@ class RagPipelineImportApi(Resource): @edit_permission_required @marshal_with(pipeline_import_model) @console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__]) - def post(self): + @with_current_user + def post(self, current_user: Account): # Check user role first - current_user, _ = current_account_with_tenant() payload = RagPipelineImportPayload.model_validate(console_ns.payload or {}) # Use a plain Session so that caught exceptions inside the service @@ -105,9 +107,8 @@ class RagPipelineImportConfirmApi(Resource): @account_initialization_required @edit_permission_required @marshal_with(pipeline_import_model) - def post(self, import_id: str): - current_user, _ = current_account_with_tenant() - + @with_current_user + def post(self, current_user: Account, import_id: str): with Session(db.engine, expire_on_commit=False) as session: import_service = RagPipelineDslService(session) account = current_user diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 93a830e8ee..6e343f98d7 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -25,12 +25,13 @@ from controllers.console.wraps import ( account_initialization_required, is_allow_transfer_owner, setup_required, + with_current_user, ) from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.member_fields import AccountWithRole, AccountWithRoleList from libs.helper import extract_remote_ip -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models.account import Account, TenantAccountJoin, TenantAccountRole from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError @@ -136,8 +137,8 @@ class MemberListApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) - def get(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account): if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) @@ -154,7 +155,8 @@ class MemberInviteEmailApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = MemberInvitePayload.model_validate(payload) @@ -163,7 +165,6 @@ class MemberInviteEmailApi(Resource): interface_language = args.language if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 - current_user, _ = current_account_with_tenant() inviter = current_user if not inviter.current_tenant: raise ValueError("No current tenant") @@ -223,8 +224,8 @@ class MemberCancelInviteApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, member_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def delete(self, current_user: Account, member_id: UUID): if not current_user.current_tenant: raise ValueError("No current tenant") member = db.session.get(Account, str(member_id)) @@ -256,14 +257,14 @@ class MemberUpdateRoleApi(Resource): @setup_required @login_required @account_initialization_required - def put(self, member_id: UUID): + @with_current_user + def put(self, current_user: Account, member_id: UUID): payload = console_ns.payload or {} args = MemberRoleUpdatePayload.model_validate(payload) new_role = args.role if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 - current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") if not _is_role_enabled(new_role, current_user.current_tenant.id): @@ -297,8 +298,8 @@ class DatasetOperatorMemberListApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) - def get(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account): if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) @@ -317,13 +318,13 @@ class SendOwnerTransferEmailApi(Resource): @login_required @account_initialization_required @is_allow_transfer_owner - def post(self): + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = OwnerTransferEmailPayload.model_validate(payload) ip_address = extract_remote_ip(request) if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() - current_user, _ = current_account_with_tenant() # check if the current user is the owner of the workspace if not current_user.current_tenant: raise ValueError("No current tenant") @@ -355,11 +356,11 @@ class OwnerTransferCheckApi(Resource): @login_required @account_initialization_required @is_allow_transfer_owner - def post(self): + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = OwnerTransferCheckPayload.model_validate(payload) # check if the current user is the owner of the workspace - current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): @@ -399,12 +400,12 @@ class OwnerTransfer(Resource): @login_required @account_initialization_required @is_allow_transfer_owner - def post(self, member_id: UUID): + @with_current_user + def post(self, current_user: Account, member_id: UUID): payload = console_ns.payload or {} args = OwnerTransferPayload.model_validate(payload) # check if the current user is the owner of the workspace - current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index b13bdba2bc..5edae75f52 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -400,15 +400,10 @@ class TestSiteEndpoints: "session", MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) - monkeypatch.setattr( - site_module, - "current_account_with_tenant", - lambda: (SimpleNamespace(id="u1"), "t1"), - ) monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now") with app.test_request_context("/", json={"title": "My Site"}): - result = method(app_model=SimpleNamespace(id="app-1")) + result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1")) assert isinstance(result, dict) assert result["title"] == "My Site" @@ -439,15 +434,10 @@ class TestSiteEndpoints: MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None), ) monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") - monkeypatch.setattr( - site_module, - "current_account_with_tenant", - lambda: (SimpleNamespace(id="u1"), "t1"), - ) monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now") with app.test_request_context("/"): - result = method(app_model=SimpleNamespace(id="app-1")) + result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1")) assert isinstance(result, dict) assert result["access_token"] == "code" @@ -586,11 +576,6 @@ class TestWorkflowStatisticEndpoints: "create_api_workflow_run_repository", lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]), ) - monkeypatch.setattr( - workflow_statistic_module, - "current_account_with_tenant", - lambda: (SimpleNamespace(timezone="UTC"), "t1"), - ) monkeypatch.setattr( workflow_statistic_module, "parse_time_range", @@ -601,7 +586,7 @@ class TestWorkflowStatisticEndpoints: method = _unwrap(api.get) with app.test_request_context("/"): - response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-01"}]} @@ -614,11 +599,6 @@ class TestWorkflowStatisticEndpoints: get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}] ), ) - monkeypatch.setattr( - workflow_statistic_module, - "current_account_with_tenant", - lambda: (SimpleNamespace(timezone="UTC"), "t1"), - ) monkeypatch.setattr( workflow_statistic_module, "parse_time_range", @@ -629,7 +609,7 @@ class TestWorkflowStatisticEndpoints: method = _unwrap(api.get) with app.test_request_context("/"): - response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-02"}]} diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index bcb6e41ef7..520ee67ee0 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -50,10 +50,9 @@ class TestAppImportApi: "import_app", lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), ) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method() + response, status = method(SimpleNamespace(id="u1")) assert status == 400 assert response["status"] == ImportStatus.FAILED @@ -68,10 +67,9 @@ class TestAppImportApi: "import_app", lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), ) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method() + response, status = method(SimpleNamespace(id="u1")) assert status == 202 assert response["status"] == ImportStatus.PENDING @@ -88,10 +86,9 @@ class TestAppImportApi: ) update_access = MagicMock() monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method() + response, status = method(SimpleNamespace(id="u1")) update_access.assert_called_once_with("app-123", "private") assert status == 200 @@ -107,7 +104,6 @@ class TestAppImportApi: "import_app", lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), ) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) fake_session = MagicMock() fake_session.__enter__.return_value = fake_session @@ -115,7 +111,7 @@ class TestAppImportApi: monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method() + response, status = method(SimpleNamespace(id="u1")) fake_session.commit.assert_called_once_with() fake_session.rollback.assert_not_called() @@ -132,7 +128,6 @@ class TestAppImportApi: "import_app", lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), ) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) fake_session = MagicMock() fake_session.__enter__.return_value = fake_session @@ -140,7 +135,7 @@ class TestAppImportApi: monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method() + response, status = method(SimpleNamespace(id="u1")) fake_session.rollback.assert_called_once_with() fake_session.commit.assert_not_called() @@ -162,10 +157,9 @@ class TestAppImportConfirmApi: "confirm_import", lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), ) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): - response, status = method(import_id="import-1") + response, status = method(SimpleNamespace(id="u1"), import_id="import-1") assert status == 400 assert response["status"] == ImportStatus.FAILED diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py index fad0b8b10e..05c8452da8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -36,19 +36,12 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged( read_at = datetime(2026, 2, 9, 0, 0, 0) - with ( - patch( - "controllers.console.app.conversation.current_account_with_tenant", - return_value=(account, tenant.id), - autospec=True, - ), - patch( - "controllers.console.app.conversation.naive_utc_now", - return_value=read_at, - autospec=True, - ), + with patch( + "controllers.console.app.conversation.naive_utc_now", + return_value=read_at, + autospec=True, ): - loaded = _get_conversation(app, conversation.id) + loaded = _get_conversation(account, app, conversation.id) db_session_with_containers.refresh(conversation) @@ -64,10 +57,5 @@ def test_get_conversation_raises_not_found_for_missing_conversation( account, tenant = create_console_account_and_tenant(db_session_with_containers) app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) - with patch( - "controllers.console.app.conversation.current_account_with_tenant", - return_value=(account, tenant.id), - autospec=True, - ): - with pytest.raises(NotFound): - _get_conversation(app, "00000000-0000-0000-0000-000000000000") + with pytest.raises(NotFound): + _get_conversation(account, app, "00000000-0000-0000-0000-000000000000") diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index 44eb5c336c..f6c4ebda3e 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -52,16 +52,12 @@ class TestRagPipelineImportApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, ), ): - response, status = method(api) + response, status = method(api, user) assert status == 200 assert response == {"status": "success"} @@ -82,16 +78,12 @@ class TestRagPipelineImportApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, ), ): - response, status = method(api) + response, status = method(api, user) assert status == 400 assert response == {"status": "failed"} @@ -112,16 +104,12 @@ class TestRagPipelineImportApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, ), ): - response, status = method(api) + response, status = method(api, user) assert status == 202 assert response == {"status": "pending"} @@ -146,16 +134,12 @@ class TestRagPipelineImportConfirmApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, ), ): - response, status = method(api, "import-1") + response, status = method(api, user, "import-1") assert status == 200 assert response == {"ok": True} @@ -174,16 +158,12 @@ class TestRagPipelineImportConfirmApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", return_value=service, ), ): - response, status = method(api, "import-1") + response, status = method(api, user, "import-1") assert status == 400 assert response == {"ok": False} diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_members.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_members.py index f9c0c4d669..3c1688293e 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_members.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_members.py @@ -104,10 +104,9 @@ class TestMemberCancelInviteApiWithContainers: with ( flask_app_with_containers.test_request_context("/"), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), patch.object(members_module.TenantService, "remove_member_from_tenant") as mock_remove_member, ): - result, status = method(api, member.id) + result, status = method(api, current_user, member.id) assert status == 200 assert result["result"] == "success" @@ -123,12 +122,9 @@ class TestMemberCancelInviteApiWithContainers: factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) - with ( - flask_app_with_containers.test_request_context("/"), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), - ): + with flask_app_with_containers.test_request_context("/"): with pytest.raises(HTTPException): - method(api, str(uuid4())) + method(api, current_user, str(uuid4())) def test_cancel_cannot_operate_self(self, flask_app_with_containers, db_session_with_containers): api = MemberCancelInviteApi() @@ -139,14 +135,13 @@ class TestMemberCancelInviteApiWithContainers: with ( flask_app_with_containers.test_request_context("/"), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), patch.object( members_module.TenantService, "remove_member_from_tenant", side_effect=services.errors.account.CannotOperateSelfError("x"), ), ): - result, status = method(api, member.id) + result, status = method(api, current_user, member.id) assert status == 400 assert result["code"] == "cannot-operate-self" @@ -160,14 +155,13 @@ class TestMemberCancelInviteApiWithContainers: with ( flask_app_with_containers.test_request_context("/"), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), patch.object( members_module.TenantService, "remove_member_from_tenant", side_effect=services.errors.account.NoPermissionError("x"), ), ): - result, status = method(api, member.id) + result, status = method(api, current_user, member.id) assert status == 403 assert result["code"] == "forbidden" @@ -181,14 +175,13 @@ class TestMemberCancelInviteApiWithContainers: with ( flask_app_with_containers.test_request_context("/"), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), patch.object( members_module.TenantService, "remove_member_from_tenant", side_effect=services.errors.account.MemberNotInTenantError(), ), ): - result, status = method(api, member.id) + result, status = method(api, current_user, member.id) assert status == 404 assert result["code"] == "member-not-found" @@ -207,11 +200,8 @@ class TestMemberUpdateRoleApiWithContainers: role=TenantAccountRole.EDITOR, ) - with ( - flask_app_with_containers.test_request_context("/", json={"role": "normal"}), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), - ): - result = method(api, member.id) + with flask_app_with_containers.test_request_context("/", json={"role": "normal"}): + result = method(api, current_user, member.id) if isinstance(result, tuple): result = result[0] @@ -227,12 +217,9 @@ class TestMemberUpdateRoleApiWithContainers: factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) - with ( - flask_app_with_containers.test_request_context("/", json={"role": "normal"}), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), - ): + with flask_app_with_containers.test_request_context("/", json={"role": "normal"}): with pytest.raises(HTTPException): - method(api, str(uuid4())) + method(api, current_user, str(uuid4())) class TestOwnerTransferApiWithContainers: @@ -244,12 +231,9 @@ class TestOwnerTransferApiWithContainers: member = factory.create_account(db_session_with_containers, email_prefix="member") token = factory.create_owner_transfer_token(current_user) - with ( - flask_app_with_containers.test_request_context("/", json={"token": token}), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), - ): + with flask_app_with_containers.test_request_context("/", json={"token": token}): with pytest.raises(MemberNotInTenantError): - method(api, member.id) + method(api, current_user, member.id) def test_member_not_found(self, flask_app_with_containers, db_session_with_containers): api = OwnerTransfer() @@ -258,12 +242,9 @@ class TestOwnerTransferApiWithContainers: tenant, current_user = factory.create_owner_workspace(db_session_with_containers) token = factory.create_owner_transfer_token(current_user) - with ( - flask_app_with_containers.test_request_context("/", json={"token": token}), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), - ): + with flask_app_with_containers.test_request_context("/", json={"token": token}): with pytest.raises(HTTPException): - method(api, str(uuid4())) + method(api, current_user, str(uuid4())) def test_transfer_success(self, flask_app_with_containers, db_session_with_containers): api = OwnerTransfer() @@ -280,11 +261,10 @@ class TestOwnerTransferApiWithContainers: with ( flask_app_with_containers.test_request_context("/", json={"token": token}), - patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)), patch.object(members_module.AccountService, "send_new_owner_transfer_notify_email") as mock_new_owner_email, patch.object(members_module.AccountService, "send_old_owner_transfer_notify_email") as mock_old_owner_email, ): - result = method(api, member.id) + result = method(api, current_user, member.id) assert result["result"] == "success" assert ( diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py index e690968ffb..386f75e231 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -61,10 +61,9 @@ class TestAppImportApi: "import_app", lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), ) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method() + response, status = method(SimpleNamespace(id="u1")) session.rollback.assert_called_once_with() session.commit.assert_not_called() @@ -83,10 +82,9 @@ class TestAppImportApi: "import_app", lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), ) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method() + response, status = method(SimpleNamespace(id="u1")) session.commit.assert_called_once_with() session.rollback.assert_not_called() @@ -107,10 +105,9 @@ class TestAppImportApi: ) update_access = MagicMock() monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method() + response, status = method(SimpleNamespace(id="u1")) session.commit.assert_called_once_with() session.rollback.assert_not_called() @@ -135,10 +132,9 @@ class TestAppImportConfirmApi: "confirm_import", lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), ) - monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): - response, status = method(import_id="import-1") + response, status = method(SimpleNamespace(id="u1"), import_id="import-1") session.rollback.assert_called_once_with() session.commit.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 24b7e39f73..41924bbfd3 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -29,7 +29,6 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: method = _unwrap(api.get) account = _make_account() - monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1")) monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) paginate_result = MagicMock() @@ -41,7 +40,7 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(account, app_model=SimpleNamespace(id="app-1")) assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} @@ -51,7 +50,6 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes method = _unwrap(api.get) account = _make_account() - monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1")) monkeypatch.setattr( conversation_module, "parse_time_range", @@ -64,7 +62,7 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes query_string={"start": "bad"}, ): with pytest.raises(BadRequest): - method(app_model=SimpleNamespace(id="app-1")) + method(account, app_model=SimpleNamespace(id="app-1")) def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -72,7 +70,6 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p method = _unwrap(api.get) account = _make_account() - monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1")) monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) paginate_result = MagicMock() @@ -84,7 +81,7 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) + response = method(account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} @@ -95,10 +92,9 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No session = MagicMock() session.scalar.return_value = conversation - monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) - result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1") + result = conversation_module._get_conversation(_make_account(), SimpleNamespace(id="app-1"), "c1") assert result is conversation session.execute.assert_called_once() @@ -110,18 +106,16 @@ def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPat session = MagicMock() session.scalar.return_value = None - monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr(conversation_module.db, "session", session) with pytest.raises(NotFound): - conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing") + conversation_module._get_conversation(_make_account(), SimpleNamespace(id="app-1"), "missing") def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None: api = conversation_module.CompletionConversationDetailApi() method = _unwrap(api.delete) - monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) monkeypatch.setattr( conversation_module.ConversationService, "delete", @@ -129,4 +123,4 @@ def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.Monke ) with pytest.raises(NotFound): - method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1") + method(_make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1") diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic_api.py b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py index 15459994f9..4093398341 100644 --- a/api/tests/unit_tests/controllers/console/app/test_statistic_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py @@ -38,11 +38,6 @@ def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None: def _install_common(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr( - statistic_module, - "current_account_with_tenant", - lambda: (SimpleNamespace(timezone="UTC"), "t1"), - ) monkeypatch.setattr( statistic_module, "parse_time_range", @@ -60,7 +55,7 @@ def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPat _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]} @@ -74,7 +69,7 @@ def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.Monk _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} @@ -88,7 +83,7 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) data = response.get_json() assert len(data["data"]) == 1 @@ -106,7 +101,7 @@ def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyP _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]} @@ -128,17 +123,12 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes raise ValueError("Invalid time range") _install_db(monkeypatch, []) - monkeypatch.setattr( - statistic_module, - "current_account_with_tenant", - lambda: (SimpleNamespace(timezone="UTC"), "t1"), - ) monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse) monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): with pytest.raises(BadRequest): - method(app_model=SimpleNamespace(id="app-1")) + method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -154,7 +144,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) data = response.get_json() assert len(data["data"]) == 3 @@ -168,7 +158,7 @@ def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPat _install_db(monkeypatch, []) with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": []} @@ -179,11 +169,6 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)] _install_db(monkeypatch, rows) - monkeypatch.setattr( - statistic_module, - "current_account_with_tenant", - lambda: (SimpleNamespace(timezone="UTC"), "t1"), - ) monkeypatch.setattr( statistic_module, "parse_time_range", @@ -192,7 +177,7 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} @@ -209,7 +194,7 @@ def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.Monk _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) data = response.get_json() assert len(data["data"]) == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index ace2ce5706..906688d8c8 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -41,6 +41,15 @@ def encode_code(code: str) -> str: return base64.b64encode(code.encode("utf-8")).decode() +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + class TestLoginApi: """Test cases for the LoginApi endpoint.""" @@ -486,13 +495,9 @@ class TestLogoutApi: account.email = "test@example.com" return account - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.login.current_account_with_tenant") @patch("controllers.console.auth.login.AccountService.logout") @patch("controllers.console.auth.login.flask_login.logout_user") - def test_successful_logout( - self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app: Flask, mock_account - ): + def test_successful_logout(self, mock_logout_user, mock_service_logout, app: Flask, mock_account): """ Test successful logout flow. @@ -502,23 +507,18 @@ class TestLogoutApi: - All authentication cookies are cleared - Success response is returned """ - # Arrange - mock_current_account.return_value = (mock_account, MagicMock()) - # Act with app.test_request_context("/logout", method="POST"): logout_api = LogoutApi() - response = logout_api.post() + response = _unwrap(logout_api.post)(mock_account) # Assert mock_service_logout.assert_called_once_with(account=mock_account) mock_logout_user.assert_called_once() assert response.json["result"] == "success" - @patch("controllers.console.wraps.db") - @patch("controllers.console.auth.login.current_account_with_tenant") @patch("controllers.console.auth.login.flask_login") - def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app: Flask): + def test_logout_anonymous_user(self, mock_flask_login, app: Flask): """ Test logout for anonymous (not logged in) user. @@ -532,12 +532,11 @@ class TestLogoutApi: anonymous_user = MagicMock() mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {}) anonymous_user.__class__ = mock_flask_login.AnonymousUserMixin - mock_current_account.return_value = (anonymous_user, None) # Act with app.test_request_context("/logout", method="POST"): logout_api = LogoutApi() - response = logout_api.post() + response = _unwrap(logout_api.post)(anonymous_user) # Assert assert response.json["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py index b2863fc8cd..3015ed6604 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -96,10 +96,6 @@ class TestDatasetMetadataCreateApi: new_callable=PropertyMock, return_value=payload, ), - patch( - "controllers.console.datasets.metadata.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( MetadataArgs, "model_validate", @@ -120,7 +116,7 @@ class TestDatasetMetadataCreateApi: return_value={"id": "m1", "type": "string", "name": "author"}, ), ): - result, status = method(api, dataset_id) + result, status = method(api, current_user, dataset_id) assert status == 201 assert result["type"] == "string" @@ -143,10 +139,6 @@ class TestDatasetMetadataCreateApi: new_callable=PropertyMock, return_value=valid_payload, ), - patch( - "controllers.console.datasets.metadata.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( MetadataArgs, "model_validate", @@ -159,7 +151,7 @@ class TestDatasetMetadataCreateApi: ), ): with pytest.raises(NotFound, match="Dataset not found"): - method(api, dataset_id) + method(api, current_user, dataset_id) class TestDatasetMetadataGetApi: @@ -220,10 +212,6 @@ class TestDatasetMetadataApi: new_callable=PropertyMock, return_value=payload, ), - patch( - "controllers.console.datasets.metadata.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_dataset", @@ -239,7 +227,7 @@ class TestDatasetMetadataApi: return_value={"id": "m1", "type": "string", "name": "updated-name"}, ), ): - result, status = method(api, dataset_id, metadata_id) + result, status = method(api, current_user, dataset_id, metadata_id) assert status == 200 assert result["type"] == "string" @@ -251,10 +239,6 @@ class TestDatasetMetadataApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.metadata.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_dataset", @@ -269,7 +253,7 @@ class TestDatasetMetadataApi: "delete_metadata", ), ): - result, status = method(api, dataset_id, metadata_id) + result, status = method(api, current_user, dataset_id, metadata_id) assert status == 204 assert result == "" @@ -307,10 +291,6 @@ class TestDatasetMetadataBuiltInFieldActionApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.metadata.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_dataset", @@ -325,7 +305,7 @@ class TestDatasetMetadataBuiltInFieldActionApi: "enable_built_in_field", ), ): - result, status = method(api, dataset_id, "enable") + result, status = method(api, current_user, dataset_id, "enable") assert status == 204 assert result == "" @@ -346,10 +326,6 @@ class TestDocumentMetadataEditApi: new_callable=PropertyMock, return_value=payload, ), - patch( - "controllers.console.datasets.metadata.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_dataset", @@ -369,7 +345,7 @@ class TestDocumentMetadataEditApi: "update_documents_metadata", ), ): - result, status = method(api, dataset_id) + result, status = method(api, current_user, dataset_id) assert status == 204 assert result == "" diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index c45271eb13..d8d1d02a16 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -1,6 +1,6 @@ from contextlib import nullcontext from types import SimpleNamespace -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest from flask import Flask, g @@ -39,24 +39,14 @@ class TestMemberInviteEmailApi: @patch("controllers.console.workspace.members.FeatureService.get_features") @patch("controllers.console.workspace.members.RegisterService.invite_new_member") - @patch("controllers.console.workspace.members.current_account_with_tenant") @patch("controllers.console.wraps.db") @patch("libs.login.check_csrf_token", return_value=None) - def test_invite_normalizes_emails( - self, - mock_csrf, - mock_db, - mock_current_account, - mock_invite_member, - mock_get_features, - app: Flask, - ): + def test_invite_normalizes_emails(self, mock_csrf, mock_db, mock_invite_member, mock_get_features, app: Flask): mock_get_features.return_value = _build_feature_flags() mock_invite_member.return_value = "token-abc" tenant = SimpleNamespace(id="tenant-1", name="Test Tenant") inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active") - mock_current_account.return_value = (inviter, tenant.id) with ( patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"), @@ -84,5 +74,5 @@ class TestMemberInviteEmailApi: assert call_args.kwargs["email"] == "user@example.com" assert call_args.kwargs["language"] == "en-US" assert call_args.kwargs["role"] == TenantAccountRole.EDITOR - assert call_args.kwargs["inviter"] == inviter - mock_csrf.assert_called_once() + assert call_args.kwargs["inviter"] == account + mock_csrf.assert_called_once_with(ANY, account.id) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py index 38e745ee5e..494cbbf0c3 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_members.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -49,10 +49,9 @@ class TestMemberListApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members), ): - result, status = method(api) + result, status = method(api, user) assert status == 200 assert len(result["accounts"]) == 1 @@ -65,10 +64,9 @@ class TestMemberListApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), ): with pytest.raises(ValueError): - method(api) + method(api, user) class TestMemberInviteEmailApi: @@ -96,7 +94,6 @@ class TestMemberInviteEmailApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), patch("controllers.console.workspace.members._count_new_member_invites", return_value=1), patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"), @@ -104,7 +101,7 @@ class TestMemberInviteEmailApi: patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False), patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False), ): - result, status = method(api) + result, status = method(api, user) assert status == 201 assert result["result"] == "success" @@ -127,14 +124,13 @@ class TestMemberInviteEmailApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), patch("controllers.console.workspace.members._count_new_member_invites", return_value=1), patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", True), patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False), ): with pytest.raises(WorkspaceMembersLimitExceeded): - method(api) + method(api, user) def test_invite_billing_limit_exceeded(self, app: Flask): api = MemberInviteEmailApi() @@ -155,7 +151,6 @@ class TestMemberInviteEmailApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), patch("controllers.console.workspace.members._count_new_member_invites", return_value=2), patch("controllers.console.workspace.members._count_current_members", return_value=9), @@ -163,7 +158,7 @@ class TestMemberInviteEmailApi: patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", True), ): with pytest.raises(WorkspaceMembersLimitExceeded): - method(api) + method(api, user) def test_invite_already_member(self, app: Flask): api = MemberInviteEmailApi() @@ -183,7 +178,6 @@ class TestMemberInviteEmailApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), patch("controllers.console.workspace.members._count_new_member_invites", return_value=0), patch( @@ -194,7 +188,7 @@ class TestMemberInviteEmailApi: patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False), patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False), ): - result, status = method(api) + result, status = method(api, user) assert result["invitation_results"][0]["status"] == "success" @@ -208,7 +202,7 @@ class TestMemberInviteEmailApi: } with app.test_request_context("/", json=payload): - result, status = method(api) + result, status = method(api, MagicMock()) assert status == 400 assert result["code"] == "invalid-role" @@ -231,7 +225,6 @@ class TestMemberInviteEmailApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), patch("controllers.console.workspace.members._count_new_member_invites", return_value=1), patch( @@ -242,7 +235,7 @@ class TestMemberInviteEmailApi: patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False), patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False), ): - result, _ = method(api) + result, _ = method(api, user) assert result["invitation_results"][0]["status"] == "failed" @@ -255,7 +248,7 @@ class TestMemberUpdateRoleApi: payload = {"role": "invalid-role"} with app.test_request_context("/", json=payload): - result, status = method(api, "id") + result, status = method(api, MagicMock(), "id") assert status == 400 @@ -278,12 +271,11 @@ class TestDatasetOperatorMemberListApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch( "controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members ), ): - result, status = method(api) + result, status = method(api, user) assert status == 200 assert len(result["accounts"]) == 1 @@ -296,10 +288,9 @@ class TestDatasetOperatorMemberListApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), ): with pytest.raises(ValueError): - method(api) + method(api, user) class TestSendOwnerTransferEmailApi: @@ -316,13 +307,12 @@ class TestSendOwnerTransferEmailApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), patch( "controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token" ), ): - result = method(api) + result = method(api, user) assert result["result"] == "success" @@ -338,7 +328,7 @@ class TestSendOwnerTransferEmailApi: patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True), ): with pytest.raises(EmailSendIpLimitError): - method(api) + method(api, MagicMock()) def test_send_not_owner(self, app: Flask): api = SendOwnerTransferEmailApi() @@ -351,11 +341,10 @@ class TestSendOwnerTransferEmailApi: app.test_request_context("/", json={}), patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False), ): with pytest.raises(NotOwnerError): - method(api) + method(api, user) class TestOwnerTransferCheckApi: @@ -370,7 +359,6 @@ class TestOwnerTransferCheckApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), patch( "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", @@ -382,7 +370,7 @@ class TestOwnerTransferCheckApi: ), ): with pytest.raises(EmailCodeError): - method(api) + method(api, user) def test_rate_limited(self, app: Flask): api = OwnerTransferCheckApi() @@ -395,7 +383,6 @@ class TestOwnerTransferCheckApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), patch( "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", @@ -403,7 +390,7 @@ class TestOwnerTransferCheckApi: ), ): with pytest.raises(OwnerTransferLimitError): - method(api) + method(api, user) def test_invalid_token(self, app: Flask): api = OwnerTransferCheckApi() @@ -416,7 +403,6 @@ class TestOwnerTransferCheckApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), patch( "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", @@ -425,7 +411,7 @@ class TestOwnerTransferCheckApi: patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None), ): with pytest.raises(InvalidTokenError): - method(api) + method(api, user) def test_invalid_email(self, app: Flask): api = OwnerTransferCheckApi() @@ -438,7 +424,6 @@ class TestOwnerTransferCheckApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), patch( "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", @@ -450,7 +435,7 @@ class TestOwnerTransferCheckApi: ), ): with pytest.raises(InvalidEmailError): - method(api) + method(api, user) class TestOwnerTransferApi: @@ -465,11 +450,10 @@ class TestOwnerTransferApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), ): with pytest.raises(CannotTransferOwnerToSelfError): - method(api, "1") + method(api, user, "1") def test_invalid_token(self, app: Flask): api = OwnerTransfer() @@ -482,9 +466,8 @@ class TestOwnerTransferApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None), ): with pytest.raises(InvalidTokenError): - method(api, "2") + method(api, user, "2")