refactor: inject current user into user-only controllers (#36754)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
Tianle 2026-05-31 10:03:15 -05:00 committed by GitHub
parent d8571ce965
commit 0a3005701f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 220 additions and 365 deletions

View File

@ -9,9 +9,11 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.model import App
from services.app_dsl_service import AppDslService, Import
from services.enterprise.enterprise_service import EnterpriseService
@ -48,9 +50,9 @@ class AppImportApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@with_current_user
def post(self, current_user: Account):
# Check user role first
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
# AppDslService performs internal commits for some creation paths, so use a plain
@ -97,10 +99,9 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
@with_current_user
def post(self, current_user: Account, import_id: str):
# Check user role first
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
import_service = AppDslService(session)
# Confirm import

View File

@ -12,7 +12,12 @@ from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
@ -31,8 +36,9 @@ from fields.conversation_fields import (
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
)
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.account import Account
from models.model import App, AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
@ -93,8 +99,8 @@ class CompletionConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, app_model: App):
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
query = sa.select(Conversation).where(
@ -165,10 +171,11 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def get(self, app_model: App, conversation_id: UUID):
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
return ConversationMessageDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_completion_conversation")
@ -182,8 +189,8 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model: App, conversation_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
try:
@ -207,8 +214,8 @@ class ChatConversationApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, app_model: App):
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
subquery = (
@ -318,10 +325,11 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
def get(self, app_model: App, conversation_id: UUID):
@with_current_user
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
return ConversationDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_chat_conversation")
@ -335,8 +343,8 @@ class ChatConversationDetailApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model: App, conversation_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
conversation_id_str = str(conversation_id)
try:
@ -347,8 +355,7 @@ class ChatConversationDetailApi(Resource):
return "", 204
def _get_conversation(app_model, conversation_id):
current_user, _ = current_account_with_tenant()
def _get_conversation(current_user: Account, app_model, conversation_id):
conversation = db.session.scalar(
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
)

View File

@ -25,6 +25,7 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
@ -43,7 +44,8 @@ from fields.conversation_fields import (
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import to_timestamp, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
@ -257,9 +259,8 @@ class MessageFeedbackApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, app_model: App):
args = MessageFeedbackPayload.model_validate(console_ns.payload)
message_id = str(args.message_id)
@ -337,8 +338,8 @@ class MessageSuggestedQuestionApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model: App, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, app_model: App, message_id: UUID):
message_id_str = str(message_id)
try:

View File

@ -14,12 +14,14 @@ from controllers.console.wraps import (
edit_permission_required,
is_admin_or_owner_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Site
from models.account import Account
from models.model import App
@ -85,9 +87,9 @@ class AppSite(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model
def post(self, app_model: App):
@with_current_user
def post(self, current_user: Account, app_model: App):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:
raise NotFound
@ -134,8 +136,8 @@ class AppSiteAccessTokenReset(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, app_model: App):
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
if not site:

View File

@ -8,13 +8,14 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import AppMode
from models.account import Account
from models.model import App
@ -48,9 +49,8 @@ class DailyMessageStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -109,9 +109,8 @@ class DailyConversationStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -169,9 +168,8 @@ class DailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -230,9 +228,8 @@ class DailyTokenCostStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -294,9 +291,8 @@ class AverageSessionInteractionStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("c.created_at")
@ -374,9 +370,8 @@ class UserSatisfactionRateStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("m.created_at")
@ -444,9 +439,8 @@ class AverageResponseTimeStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")
@ -505,8 +499,8 @@ class TokensPerSecondStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
converted_created_at = convert_datetime_to_date("created_at")

View File

@ -6,10 +6,11 @@ from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.enums import WorkflowRunTriggeredFrom
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
@ -46,9 +47,8 @@ class WorkflowDailyRunsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -86,9 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -126,9 +125,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None
@ -166,9 +164,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def get(self, app_model: App):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account, app_model: App):
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
assert account.timezone is not None

View File

@ -32,11 +32,11 @@ from controllers.console.wraps import (
decrypt_password_field,
email_password_login_enabled,
setup_required,
with_current_user,
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
@ -46,6 +46,7 @@ from libs.token import (
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models.account import Account
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
@ -172,9 +173,8 @@ class LoginApi(Resource):
class LogoutApi(Resource):
@setup_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
account = current_user
@with_current_user
def post(self, account: Account):
if isinstance(account, flask_login.AnonymousUserMixin):
response = make_response({"result": "success"})
else:

View File

@ -7,14 +7,20 @@ from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
with_current_user,
)
from fields.dataset_fields import (
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
@ -43,8 +49,8 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
@ -80,8 +86,8 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id: UUID, metadata_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
@ -100,8 +106,8 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Metadata deleted successfully")
def delete(self, dataset_id: UUID, metadata_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -137,8 +143,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Action completed successfully")
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -165,8 +171,8 @@ class DocumentMetadataEditApi(Resource):
204,
"Documents metadata updated successfully",
)
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -10,6 +10,7 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import (
@ -17,7 +18,8 @@ from fields.rag_pipeline_fields import (
pipeline_import_check_dependencies_fields,
pipeline_import_fields,
)
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from models.dataset import Pipeline
from services.entities.dsl_entities import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
@ -62,9 +64,9 @@ class RagPipelineImportApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_model)
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
def post(self):
@with_current_user
def post(self, current_user: Account):
# Check user role first
current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Use a plain Session so that caught exceptions inside the service
@ -105,9 +107,8 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
def post(self, import_id: str):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, import_id: str):
with Session(db.engine, expire_on_commit=False) as session:
import_service = RagPipelineDslService(session)
account = current_user

View File

@ -25,12 +25,13 @@ from controllers.console.wraps import (
account_initialization_required,
is_allow_transfer_owner,
setup_required,
with_current_user,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account, TenantAccountJoin, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@ -136,8 +137,8 @@ class MemberListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
@ -154,7 +155,8 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
@ -163,7 +165,6 @@ class MemberInviteEmailApi(Resource):
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
@ -223,8 +224,8 @@ class MemberCancelInviteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, member_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, member_id: UUID):
if not current_user.current_tenant:
raise ValueError("No current tenant")
member = db.session.get(Account, str(member_id))
@ -256,14 +257,14 @@ class MemberUpdateRoleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def put(self, member_id: UUID):
@with_current_user
def put(self, current_user: Account, member_id: UUID):
payload = console_ns.payload or {}
args = MemberRoleUpdatePayload.model_validate(payload)
new_role = args.role
if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not _is_role_enabled(new_role, current_user.current_tenant.id):
@ -297,8 +298,8 @@ class DatasetOperatorMemberListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
@ -317,13 +318,13 @@ class SendOwnerTransferEmailApi(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
current_user, _ = current_account_with_tenant()
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
@ -355,11 +356,11 @@ class OwnerTransferCheckApi(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
@ -399,12 +400,12 @@ class OwnerTransfer(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id: UUID):
@with_current_user
def post(self, current_user: Account, member_id: UUID):
payload = console_ns.payload or {}
args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):

View File

@ -400,15 +400,10 @@ class TestSiteEndpoints:
"session",
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
)
monkeypatch.setattr(
site_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/", json={"title": "My Site"}):
result = method(app_model=SimpleNamespace(id="app-1"))
result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
assert isinstance(result, dict)
assert result["title"] == "My Site"
@ -439,15 +434,10 @@ class TestSiteEndpoints:
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
)
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
monkeypatch.setattr(
site_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/"):
result = method(app_model=SimpleNamespace(id="app-1"))
result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
assert isinstance(result, dict)
assert result["access_token"] == "code"
@ -586,11 +576,6 @@ class TestWorkflowStatisticEndpoints:
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
)
monkeypatch.setattr(
workflow_statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
workflow_statistic_module,
"parse_time_range",
@ -601,7 +586,7 @@ class TestWorkflowStatisticEndpoints:
method = _unwrap(api.get)
with app.test_request_context("/"):
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
@ -614,11 +599,6 @@ class TestWorkflowStatisticEndpoints:
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
),
)
monkeypatch.setattr(
workflow_statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
workflow_statistic_module,
"parse_time_range",
@ -629,7 +609,7 @@ class TestWorkflowStatisticEndpoints:
method = _unwrap(api.get)
with app.test_request_context("/"):
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}

View File

@ -50,10 +50,9 @@ class TestAppImportApi:
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
response, status = method(SimpleNamespace(id="u1"))
assert status == 400
assert response["status"] == ImportStatus.FAILED
@ -68,10 +67,9 @@ class TestAppImportApi:
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
response, status = method(SimpleNamespace(id="u1"))
assert status == 202
assert response["status"] == ImportStatus.PENDING
@ -88,10 +86,9 @@ class TestAppImportApi:
)
update_access = MagicMock()
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
response, status = method(SimpleNamespace(id="u1"))
update_access.assert_called_once_with("app-123", "private")
assert status == 200
@ -107,7 +104,6 @@ class TestAppImportApi:
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
fake_session = MagicMock()
fake_session.__enter__.return_value = fake_session
@ -115,7 +111,7 @@ class TestAppImportApi:
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
response, status = method(SimpleNamespace(id="u1"))
fake_session.commit.assert_called_once_with()
fake_session.rollback.assert_not_called()
@ -132,7 +128,6 @@ class TestAppImportApi:
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
fake_session = MagicMock()
fake_session.__enter__.return_value = fake_session
@ -140,7 +135,7 @@ class TestAppImportApi:
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
response, status = method(SimpleNamespace(id="u1"))
fake_session.rollback.assert_called_once_with()
fake_session.commit.assert_not_called()
@ -162,10 +157,9 @@ class TestAppImportConfirmApi:
"confirm_import",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(import_id="import-1")
response, status = method(SimpleNamespace(id="u1"), import_id="import-1")
assert status == 400
assert response["status"] == ImportStatus.FAILED

View File

@ -36,19 +36,12 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(
read_at = datetime(2026, 2, 9, 0, 0, 0)
with (
patch(
"controllers.console.app.conversation.current_account_with_tenant",
return_value=(account, tenant.id),
autospec=True,
),
patch(
"controllers.console.app.conversation.naive_utc_now",
return_value=read_at,
autospec=True,
),
with patch(
"controllers.console.app.conversation.naive_utc_now",
return_value=read_at,
autospec=True,
):
loaded = _get_conversation(app, conversation.id)
loaded = _get_conversation(account, app, conversation.id)
db_session_with_containers.refresh(conversation)
@ -64,10 +57,5 @@ def test_get_conversation_raises_not_found_for_missing_conversation(
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
with patch(
"controllers.console.app.conversation.current_account_with_tenant",
return_value=(account, tenant.id),
autospec=True,
):
with pytest.raises(NotFound):
_get_conversation(app, "00000000-0000-0000-0000-000000000000")
with pytest.raises(NotFound):
_get_conversation(account, app, "00000000-0000-0000-0000-000000000000")

View File

@ -52,16 +52,12 @@ class TestRagPipelineImportApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
response, status = method(api, user)
assert status == 200
assert response == {"status": "success"}
@ -82,16 +78,12 @@ class TestRagPipelineImportApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
response, status = method(api, user)
assert status == 400
assert response == {"status": "failed"}
@ -112,16 +104,12 @@ class TestRagPipelineImportApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api)
response, status = method(api, user)
assert status == 202
assert response == {"status": "pending"}
@ -146,16 +134,12 @@ class TestRagPipelineImportConfirmApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, "import-1")
response, status = method(api, user, "import-1")
assert status == 200
assert response == {"ok": True}
@ -174,16 +158,12 @@ class TestRagPipelineImportConfirmApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
return_value=service,
),
):
response, status = method(api, "import-1")
response, status = method(api, user, "import-1")
assert status == 400
assert response == {"ok": False}

View File

@ -104,10 +104,9 @@ class TestMemberCancelInviteApiWithContainers:
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(members_module.TenantService, "remove_member_from_tenant") as mock_remove_member,
):
result, status = method(api, member.id)
result, status = method(api, current_user, member.id)
assert status == 200
assert result["result"] == "success"
@ -123,12 +122,9 @@ class TestMemberCancelInviteApiWithContainers:
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with flask_app_with_containers.test_request_context("/"):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
method(api, current_user, str(uuid4()))
def test_cancel_cannot_operate_self(self, flask_app_with_containers, db_session_with_containers):
api = MemberCancelInviteApi()
@ -139,14 +135,13 @@ class TestMemberCancelInviteApiWithContainers:
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
result, status = method(api, member.id)
result, status = method(api, current_user, member.id)
assert status == 400
assert result["code"] == "cannot-operate-self"
@ -160,14 +155,13 @@ class TestMemberCancelInviteApiWithContainers:
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
result, status = method(api, member.id)
result, status = method(api, current_user, member.id)
assert status == 403
assert result["code"] == "forbidden"
@ -181,14 +175,13 @@ class TestMemberCancelInviteApiWithContainers:
with (
flask_app_with_containers.test_request_context("/"),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(
members_module.TenantService,
"remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
result, status = method(api, member.id)
result, status = method(api, current_user, member.id)
assert status == 404
assert result["code"] == "member-not-found"
@ -207,11 +200,8 @@ class TestMemberUpdateRoleApiWithContainers:
role=TenantAccountRole.EDITOR,
)
with (
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
result = method(api, member.id)
with flask_app_with_containers.test_request_context("/", json={"role": "normal"}):
result = method(api, current_user, member.id)
if isinstance(result, tuple):
result = result[0]
@ -227,12 +217,9 @@ class TestMemberUpdateRoleApiWithContainers:
factory = WorkspaceMembersIntegrationFactory
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
with (
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with flask_app_with_containers.test_request_context("/", json={"role": "normal"}):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
method(api, current_user, str(uuid4()))
class TestOwnerTransferApiWithContainers:
@ -244,12 +231,9 @@ class TestOwnerTransferApiWithContainers:
member = factory.create_account(db_session_with_containers, email_prefix="member")
token = factory.create_owner_transfer_token(current_user)
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with flask_app_with_containers.test_request_context("/", json={"token": token}):
with pytest.raises(MemberNotInTenantError):
method(api, member.id)
method(api, current_user, member.id)
def test_member_not_found(self, flask_app_with_containers, db_session_with_containers):
api = OwnerTransfer()
@ -258,12 +242,9 @@ class TestOwnerTransferApiWithContainers:
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
token = factory.create_owner_transfer_token(current_user)
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
):
with flask_app_with_containers.test_request_context("/", json={"token": token}):
with pytest.raises(HTTPException):
method(api, str(uuid4()))
method(api, current_user, str(uuid4()))
def test_transfer_success(self, flask_app_with_containers, db_session_with_containers):
api = OwnerTransfer()
@ -280,11 +261,10 @@ class TestOwnerTransferApiWithContainers:
with (
flask_app_with_containers.test_request_context("/", json={"token": token}),
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
patch.object(members_module.AccountService, "send_new_owner_transfer_notify_email") as mock_new_owner_email,
patch.object(members_module.AccountService, "send_old_owner_transfer_notify_email") as mock_old_owner_email,
):
result = method(api, member.id)
result = method(api, current_user, member.id)
assert result["result"] == "success"
assert (

View File

@ -61,10 +61,9 @@ class TestAppImportApi:
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
response, status = method(SimpleNamespace(id="u1"))
session.rollback.assert_called_once_with()
session.commit.assert_not_called()
@ -83,10 +82,9 @@ class TestAppImportApi:
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
response, status = method(SimpleNamespace(id="u1"))
session.commit.assert_called_once_with()
session.rollback.assert_not_called()
@ -107,10 +105,9 @@ class TestAppImportApi:
)
update_access = MagicMock()
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
response, status = method(SimpleNamespace(id="u1"))
session.commit.assert_called_once_with()
session.rollback.assert_not_called()
@ -135,10 +132,9 @@ class TestAppImportConfirmApi:
"confirm_import",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(import_id="import-1")
response, status = method(SimpleNamespace(id="u1"), import_id="import-1")
session.rollback.assert_called_once_with()
session.commit.assert_not_called()

View File

@ -29,7 +29,6 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch:
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
paginate_result = MagicMock()
@ -41,7 +40,7 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch:
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(account, app_model=SimpleNamespace(id="app-1"))
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
@ -51,7 +50,6 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(
conversation_module,
"parse_time_range",
@ -64,7 +62,7 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes
query_string={"start": "bad"},
):
with pytest.raises(BadRequest):
method(app_model=SimpleNamespace(id="app-1"))
method(account, app_model=SimpleNamespace(id="app-1"))
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -72,7 +70,6 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
paginate_result = MagicMock()
@ -84,7 +81,7 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
response = method(account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
@ -95,10 +92,9 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No
session = MagicMock()
session.scalar.return_value = conversation
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
result = conversation_module._get_conversation(_make_account(), SimpleNamespace(id="app-1"), "c1")
assert result is conversation
session.execute.assert_called_once()
@ -110,18 +106,16 @@ def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPat
session = MagicMock()
session.scalar.return_value = None
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
with pytest.raises(NotFound):
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
conversation_module._get_conversation(_make_account(), SimpleNamespace(id="app-1"), "missing")
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationDetailApi()
method = _unwrap(api.delete)
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(
conversation_module.ConversationService,
"delete",
@ -129,4 +123,4 @@ def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.Monke
)
with pytest.raises(NotFound):
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
method(_make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1")

View File

@ -38,11 +38,6 @@ def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
statistic_module,
"parse_time_range",
@ -60,7 +55,7 @@ def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPat
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
@ -74,7 +69,7 @@ def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.Monk
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
@ -88,7 +83,7 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 1
@ -106,7 +101,7 @@ def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyP
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
@ -128,17 +123,12 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes
raise ValueError("Invalid time range")
_install_db(monkeypatch, [])
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
with pytest.raises(BadRequest):
method(app_model=SimpleNamespace(id="app-1"))
method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -154,7 +144,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 3
@ -168,7 +158,7 @@ def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPat
_install_db(monkeypatch, [])
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": []}
@ -179,11 +169,6 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_db(monkeypatch, rows)
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
statistic_module,
"parse_time_range",
@ -192,7 +177,7 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
@ -209,7 +194,7 @@ def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.Monk
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 2

View File

@ -41,6 +41,15 @@ def encode_code(code: str) -> str:
return base64.b64encode(code.encode("utf-8")).decode()
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class TestLoginApi:
"""Test cases for the LoginApi endpoint."""
@ -486,13 +495,9 @@ class TestLogoutApi:
account.email = "test@example.com"
return account
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.current_account_with_tenant")
@patch("controllers.console.auth.login.AccountService.logout")
@patch("controllers.console.auth.login.flask_login.logout_user")
def test_successful_logout(
self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app: Flask, mock_account
):
def test_successful_logout(self, mock_logout_user, mock_service_logout, app: Flask, mock_account):
"""
Test successful logout flow.
@ -502,23 +507,18 @@ class TestLogoutApi:
- All authentication cookies are cleared
- Success response is returned
"""
# Arrange
mock_current_account.return_value = (mock_account, MagicMock())
# Act
with app.test_request_context("/logout", method="POST"):
logout_api = LogoutApi()
response = logout_api.post()
response = _unwrap(logout_api.post)(mock_account)
# Assert
mock_service_logout.assert_called_once_with(account=mock_account)
mock_logout_user.assert_called_once()
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.current_account_with_tenant")
@patch("controllers.console.auth.login.flask_login")
def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app: Flask):
def test_logout_anonymous_user(self, mock_flask_login, app: Flask):
"""
Test logout for anonymous (not logged in) user.
@ -532,12 +532,11 @@ class TestLogoutApi:
anonymous_user = MagicMock()
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})
anonymous_user.__class__ = mock_flask_login.AnonymousUserMixin
mock_current_account.return_value = (anonymous_user, None)
# Act
with app.test_request_context("/logout", method="POST"):
logout_api = LogoutApi()
response = logout_api.post()
response = _unwrap(logout_api.post)(anonymous_user)
# Assert
assert response.json["result"] == "success"

View File

@ -96,10 +96,6 @@ class TestDatasetMetadataCreateApi:
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
MetadataArgs,
"model_validate",
@ -120,7 +116,7 @@ class TestDatasetMetadataCreateApi:
return_value={"id": "m1", "type": "string", "name": "author"},
),
):
result, status = method(api, dataset_id)
result, status = method(api, current_user, dataset_id)
assert status == 201
assert result["type"] == "string"
@ -143,10 +139,6 @@ class TestDatasetMetadataCreateApi:
new_callable=PropertyMock,
return_value=valid_payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
MetadataArgs,
"model_validate",
@ -159,7 +151,7 @@ class TestDatasetMetadataCreateApi:
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
method(api, current_user, dataset_id)
class TestDatasetMetadataGetApi:
@ -220,10 +212,6 @@ class TestDatasetMetadataApi:
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
@ -239,7 +227,7 @@ class TestDatasetMetadataApi:
return_value={"id": "m1", "type": "string", "name": "updated-name"},
),
):
result, status = method(api, dataset_id, metadata_id)
result, status = method(api, current_user, dataset_id, metadata_id)
assert status == 200
assert result["type"] == "string"
@ -251,10 +239,6 @@ class TestDatasetMetadataApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
@ -269,7 +253,7 @@ class TestDatasetMetadataApi:
"delete_metadata",
),
):
result, status = method(api, dataset_id, metadata_id)
result, status = method(api, current_user, dataset_id, metadata_id)
assert status == 204
assert result == ""
@ -307,10 +291,6 @@ class TestDatasetMetadataBuiltInFieldActionApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
@ -325,7 +305,7 @@ class TestDatasetMetadataBuiltInFieldActionApi:
"enable_built_in_field",
),
):
result, status = method(api, dataset_id, "enable")
result, status = method(api, current_user, dataset_id, "enable")
assert status == 204
assert result == ""
@ -346,10 +326,6 @@ class TestDocumentMetadataEditApi:
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
@ -369,7 +345,7 @@ class TestDocumentMetadataEditApi:
"update_documents_metadata",
),
):
result, status = method(api, dataset_id)
result, status = method(api, current_user, dataset_id)
assert status == 204
assert result == ""

View File

@ -1,6 +1,6 @@
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import patch
from unittest.mock import ANY, patch
import pytest
from flask import Flask, g
@ -39,24 +39,14 @@ class TestMemberInviteEmailApi:
@patch("controllers.console.workspace.members.FeatureService.get_features")
@patch("controllers.console.workspace.members.RegisterService.invite_new_member")
@patch("controllers.console.workspace.members.current_account_with_tenant")
@patch("controllers.console.wraps.db")
@patch("libs.login.check_csrf_token", return_value=None)
def test_invite_normalizes_emails(
self,
mock_csrf,
mock_db,
mock_current_account,
mock_invite_member,
mock_get_features,
app: Flask,
):
def test_invite_normalizes_emails(self, mock_csrf, mock_db, mock_invite_member, mock_get_features, app: Flask):
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"
tenant = SimpleNamespace(id="tenant-1", name="Test Tenant")
inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active")
mock_current_account.return_value = (inviter, tenant.id)
with (
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"),
@ -84,5 +74,5 @@ class TestMemberInviteEmailApi:
assert call_args.kwargs["email"] == "user@example.com"
assert call_args.kwargs["language"] == "en-US"
assert call_args.kwargs["role"] == TenantAccountRole.EDITOR
assert call_args.kwargs["inviter"] == inviter
mock_csrf.assert_called_once()
assert call_args.kwargs["inviter"] == account
mock_csrf.assert_called_once_with(ANY, account.id)

View File

@ -49,10 +49,9 @@ class TestMemberListApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
):
result, status = method(api)
result, status = method(api, user)
assert status == 200
assert len(result["accounts"]) == 1
@ -65,10 +64,9 @@ class TestMemberListApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
method(api, user)
class TestMemberInviteEmailApi:
@ -96,7 +94,6 @@ class TestMemberInviteEmailApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
@ -104,7 +101,7 @@ class TestMemberInviteEmailApi:
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
result, status = method(api)
result, status = method(api, user)
assert status == 201
assert result["result"] == "success"
@ -127,14 +124,13 @@ class TestMemberInviteEmailApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
method(api, user)
def test_invite_billing_limit_exceeded(self, app: Flask):
api = MemberInviteEmailApi()
@ -155,7 +151,6 @@ class TestMemberInviteEmailApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=2),
patch("controllers.console.workspace.members._count_current_members", return_value=9),
@ -163,7 +158,7 @@ class TestMemberInviteEmailApi:
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", True),
):
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
method(api, user)
def test_invite_already_member(self, app: Flask):
api = MemberInviteEmailApi()
@ -183,7 +178,6 @@ class TestMemberInviteEmailApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=0),
patch(
@ -194,7 +188,7 @@ class TestMemberInviteEmailApi:
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
result, status = method(api)
result, status = method(api, user)
assert result["invitation_results"][0]["status"] == "success"
@ -208,7 +202,7 @@ class TestMemberInviteEmailApi:
}
with app.test_request_context("/", json=payload):
result, status = method(api)
result, status = method(api, MagicMock())
assert status == 400
assert result["code"] == "invalid-role"
@ -231,7 +225,6 @@ class TestMemberInviteEmailApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
patch(
@ -242,7 +235,7 @@ class TestMemberInviteEmailApi:
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
):
result, _ = method(api)
result, _ = method(api, user)
assert result["invitation_results"][0]["status"] == "failed"
@ -255,7 +248,7 @@ class TestMemberUpdateRoleApi:
payload = {"role": "invalid-role"}
with app.test_request_context("/", json=payload):
result, status = method(api, "id")
result, status = method(api, MagicMock(), "id")
assert status == 400
@ -278,12 +271,11 @@ class TestDatasetOperatorMemberListApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
),
):
result, status = method(api)
result, status = method(api, user)
assert status == 200
assert len(result["accounts"]) == 1
@ -296,10 +288,9 @@ class TestDatasetOperatorMemberListApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
method(api, user)
class TestSendOwnerTransferEmailApi:
@ -316,13 +307,12 @@ class TestSendOwnerTransferEmailApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
),
):
result = method(api)
result = method(api, user)
assert result["result"] == "success"
@ -338,7 +328,7 @@ class TestSendOwnerTransferEmailApi:
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
):
with pytest.raises(EmailSendIpLimitError):
method(api)
method(api, MagicMock())
def test_send_not_owner(self, app: Flask):
api = SendOwnerTransferEmailApi()
@ -351,11 +341,10 @@ class TestSendOwnerTransferEmailApi:
app.test_request_context("/", json={}),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
):
with pytest.raises(NotOwnerError):
method(api)
method(api, user)
class TestOwnerTransferCheckApi:
@ -370,7 +359,6 @@ class TestOwnerTransferCheckApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
@ -382,7 +370,7 @@ class TestOwnerTransferCheckApi:
),
):
with pytest.raises(EmailCodeError):
method(api)
method(api, user)
def test_rate_limited(self, app: Flask):
api = OwnerTransferCheckApi()
@ -395,7 +383,6 @@ class TestOwnerTransferCheckApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
@ -403,7 +390,7 @@ class TestOwnerTransferCheckApi:
),
):
with pytest.raises(OwnerTransferLimitError):
method(api)
method(api, user)
def test_invalid_token(self, app: Flask):
api = OwnerTransferCheckApi()
@ -416,7 +403,6 @@ class TestOwnerTransferCheckApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
@ -425,7 +411,7 @@ class TestOwnerTransferCheckApi:
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api)
method(api, user)
def test_invalid_email(self, app: Flask):
api = OwnerTransferCheckApi()
@ -438,7 +424,6 @@ class TestOwnerTransferCheckApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
@ -450,7 +435,7 @@ class TestOwnerTransferCheckApi:
),
):
with pytest.raises(InvalidEmailError):
method(api)
method(api, user)
class TestOwnerTransferApi:
@ -465,11 +450,10 @@ class TestOwnerTransferApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
):
with pytest.raises(CannotTransferOwnerToSelfError):
method(api, "1")
method(api, user, "1")
def test_invalid_token(self, app: Flask):
api = OwnerTransfer()
@ -482,9 +466,8 @@ class TestOwnerTransferApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api, "2")
method(api, user, "2")