mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
refactor: inject current user into user-only controllers (#36754)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
This commit is contained in:
parent
d8571ce965
commit
0a3005701f
@ -9,9 +9,11 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, Import
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -48,9 +50,9 @@ class AppImportApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# AppDslService performs internal commits for some creation paths, so use a plain
|
||||
@ -97,10 +99,9 @@ class AppImportConfirmApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
|
||||
@ -12,7 +12,12 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
@ -31,8 +36,9 @@ from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -93,8 +99,8 @@ class CompletionConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
@ -165,10 +171,11 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@ -182,8 +189,8 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
try:
|
||||
@ -207,8 +214,8 @@ class ChatConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
subquery = (
|
||||
@ -318,10 +325,11 @@ class ChatConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@ -335,8 +343,8 @@ class ChatConversationDetailApi(Resource):
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
try:
|
||||
@ -347,8 +355,7 @@ class ChatConversationDetailApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def _get_conversation(current_user: Account, app_model, conversation_id):
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
@ -25,6 +25,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
@ -43,7 +44,8 @@ from fields.conversation_fields import (
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -257,9 +259,8 @@ class MessageFeedbackApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
|
||||
message_id = str(args.message_id)
|
||||
@ -337,8 +338,8 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
try:
|
||||
|
||||
@ -14,12 +14,14 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Site
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
|
||||
@ -85,9 +87,9 @@ class AppSite(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
@ -134,8 +136,8 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
|
||||
@ -8,13 +8,14 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import AppMode
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
|
||||
@ -48,9 +49,8 @@ class DailyMessageStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -109,9 +109,8 @@ class DailyConversationStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -169,9 +168,8 @@ class DailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -230,9 +228,8 @@ class DailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -294,9 +291,8 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
@ -374,9 +370,8 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
@ -444,9 +439,8 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -505,8 +499,8 @@ class TokensPerSecondStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
|
||||
@ -6,10 +6,11 @@ from sqlalchemy.orm import sessionmaker
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -46,9 +47,8 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -86,9 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -126,9 +125,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -166,9 +164,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -32,11 +32,11 @@ from controllers.console.wraps import (
|
||||
decrypt_password_field,
|
||||
email_password_login_enabled,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
@ -46,6 +46,7 @@ from libs.token import (
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
@ -172,9 +173,8 @@ class LoginApi(Resource):
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
|
||||
@ -7,14 +7,20 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
@ -43,8 +49,8 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -80,8 +86,8 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, metadata_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
|
||||
@ -100,8 +106,8 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Metadata deleted successfully")
|
||||
def delete(self, dataset_id: UUID, metadata_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -137,8 +143,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -165,8 +171,8 @@ class DocumentMetadataEditApi(Resource):
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import (
|
||||
@ -17,7 +18,8 @@ from fields.rag_pipeline_fields import (
|
||||
pipeline_import_check_dependencies_fields,
|
||||
pipeline_import_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
@ -62,9 +64,9 @@ class RagPipelineImportApi(Resource):
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Use a plain Session so that caught exceptions inside the service
|
||||
@ -105,9 +107,8 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
def post(self, import_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
account = current_user
|
||||
|
||||
@ -25,12 +25,13 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_allow_transfer_owner,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.member_fields import AccountWithRole, AccountWithRoleList
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
@ -136,8 +137,8 @@ class MemberListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
@ -154,7 +155,8 @@ class MemberInviteEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberInvitePayload.model_validate(payload)
|
||||
|
||||
@ -163,7 +165,6 @@ class MemberInviteEmailApi(Resource):
|
||||
interface_language = args.language
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -223,8 +224,8 @@ class MemberCancelInviteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, member_id: UUID):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.get(Account, str(member_id))
|
||||
@ -256,14 +257,14 @@ class MemberUpdateRoleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, member_id: UUID):
|
||||
@with_current_user
|
||||
def put(self, current_user: Account, member_id: UUID):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberRoleUpdatePayload.model_validate(payload)
|
||||
new_role = args.role
|
||||
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not _is_role_enabled(new_role, current_user.current_tenant.id):
|
||||
@ -297,8 +298,8 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
@ -317,13 +318,13 @@ class SendOwnerTransferEmailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferEmailPayload.model_validate(payload)
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -355,11 +356,11 @@ class OwnerTransferCheckApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferCheckPayload.model_validate(payload)
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
@ -399,12 +400,12 @@ class OwnerTransfer(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self, member_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, member_id: UUID):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferPayload.model_validate(payload)
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
||||
@ -400,15 +400,10 @@ class TestSiteEndpoints:
|
||||
"session",
|
||||
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/", json={"title": "My Site"}):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["title"] == "My Site"
|
||||
@ -439,15 +434,10 @@ class TestSiteEndpoints:
|
||||
MagicMock(scalar=lambda *_args, **_kwargs: site, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["access_token"] == "code"
|
||||
@ -586,11 +576,6 @@ class TestWorkflowStatisticEndpoints:
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
@ -601,7 +586,7 @@ class TestWorkflowStatisticEndpoints:
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
|
||||
|
||||
@ -614,11 +599,6 @@ class TestWorkflowStatisticEndpoints:
|
||||
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
@ -629,7 +609,7 @@ class TestWorkflowStatisticEndpoints:
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
|
||||
|
||||
|
||||
@ -50,10 +50,9 @@ class TestAppImportApi:
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
@ -68,10 +67,9 @@ class TestAppImportApi:
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
@ -88,10 +86,9 @@ class TestAppImportApi:
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
@ -107,7 +104,6 @@ class TestAppImportApi:
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.__enter__.return_value = fake_session
|
||||
@ -115,7 +111,7 @@ class TestAppImportApi:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
|
||||
fake_session.commit.assert_called_once_with()
|
||||
fake_session.rollback.assert_not_called()
|
||||
@ -132,7 +128,6 @@ class TestAppImportApi:
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.__enter__.return_value = fake_session
|
||||
@ -140,7 +135,7 @@ class TestAppImportApi:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
|
||||
fake_session.rollback.assert_called_once_with()
|
||||
fake_session.commit.assert_not_called()
|
||||
@ -162,10 +157,9 @@ class TestAppImportConfirmApi:
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
response, status = method(SimpleNamespace(id="u1"), import_id="import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
@ -36,19 +36,12 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged(
|
||||
|
||||
read_at = datetime(2026, 2, 9, 0, 0, 0)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.app.conversation.current_account_with_tenant",
|
||||
return_value=(account, tenant.id),
|
||||
autospec=True,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.app.conversation.naive_utc_now",
|
||||
return_value=read_at,
|
||||
autospec=True,
|
||||
),
|
||||
with patch(
|
||||
"controllers.console.app.conversation.naive_utc_now",
|
||||
return_value=read_at,
|
||||
autospec=True,
|
||||
):
|
||||
loaded = _get_conversation(app, conversation.id)
|
||||
loaded = _get_conversation(account, app, conversation.id)
|
||||
|
||||
db_session_with_containers.refresh(conversation)
|
||||
|
||||
@ -64,10 +57,5 @@ def test_get_conversation_raises_not_found_for_missing_conversation(
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT)
|
||||
|
||||
with patch(
|
||||
"controllers.console.app.conversation.current_account_with_tenant",
|
||||
return_value=(account, tenant.id),
|
||||
autospec=True,
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
_get_conversation(app, "00000000-0000-0000-0000-000000000000")
|
||||
with pytest.raises(NotFound):
|
||||
_get_conversation(account, app, "00000000-0000-0000-0000-000000000000")
|
||||
|
||||
@ -52,16 +52,12 @@ class TestRagPipelineImportApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, user)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"status": "success"}
|
||||
@ -82,16 +78,12 @@ class TestRagPipelineImportApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, user)
|
||||
|
||||
assert status == 400
|
||||
assert response == {"status": "failed"}
|
||||
@ -112,16 +104,12 @@ class TestRagPipelineImportApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, user)
|
||||
|
||||
assert status == 202
|
||||
assert response == {"status": "pending"}
|
||||
@ -146,16 +134,12 @@ class TestRagPipelineImportConfirmApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
response, status = method(api, user, "import-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
@ -174,16 +158,12 @@ class TestRagPipelineImportConfirmApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
response, status = method(api, user, "import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response == {"ok": False}
|
||||
|
||||
@ -104,10 +104,9 @@ class TestMemberCancelInviteApiWithContainers:
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(members_module.TenantService, "remove_member_from_tenant") as mock_remove_member,
|
||||
):
|
||||
result, status = method(api, member.id)
|
||||
result, status = method(api, current_user, member.id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
@ -123,12 +122,9 @@ class TestMemberCancelInviteApiWithContainers:
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
with flask_app_with_containers.test_request_context("/"):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, str(uuid4()))
|
||||
method(api, current_user, str(uuid4()))
|
||||
|
||||
def test_cancel_cannot_operate_self(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = MemberCancelInviteApi()
|
||||
@ -139,14 +135,13 @@ class TestMemberCancelInviteApiWithContainers:
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(
|
||||
members_module.TenantService,
|
||||
"remove_member_from_tenant",
|
||||
side_effect=services.errors.account.CannotOperateSelfError("x"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, member.id)
|
||||
result, status = method(api, current_user, member.id)
|
||||
|
||||
assert status == 400
|
||||
assert result["code"] == "cannot-operate-self"
|
||||
@ -160,14 +155,13 @@ class TestMemberCancelInviteApiWithContainers:
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(
|
||||
members_module.TenantService,
|
||||
"remove_member_from_tenant",
|
||||
side_effect=services.errors.account.NoPermissionError("x"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, member.id)
|
||||
result, status = method(api, current_user, member.id)
|
||||
|
||||
assert status == 403
|
||||
assert result["code"] == "forbidden"
|
||||
@ -181,14 +175,13 @@ class TestMemberCancelInviteApiWithContainers:
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/"),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(
|
||||
members_module.TenantService,
|
||||
"remove_member_from_tenant",
|
||||
side_effect=services.errors.account.MemberNotInTenantError(),
|
||||
),
|
||||
):
|
||||
result, status = method(api, member.id)
|
||||
result, status = method(api, current_user, member.id)
|
||||
|
||||
assert status == 404
|
||||
assert result["code"] == "member-not-found"
|
||||
@ -207,11 +200,8 @@ class TestMemberUpdateRoleApiWithContainers:
|
||||
role=TenantAccountRole.EDITOR,
|
||||
)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
result = method(api, member.id)
|
||||
with flask_app_with_containers.test_request_context("/", json={"role": "normal"}):
|
||||
result = method(api, current_user, member.id)
|
||||
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
@ -227,12 +217,9 @@ class TestMemberUpdateRoleApiWithContainers:
|
||||
factory = WorkspaceMembersIntegrationFactory
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"role": "normal"}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
with flask_app_with_containers.test_request_context("/", json={"role": "normal"}):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, str(uuid4()))
|
||||
method(api, current_user, str(uuid4()))
|
||||
|
||||
|
||||
class TestOwnerTransferApiWithContainers:
|
||||
@ -244,12 +231,9 @@ class TestOwnerTransferApiWithContainers:
|
||||
member = factory.create_account(db_session_with_containers, email_prefix="member")
|
||||
token = factory.create_owner_transfer_token(current_user)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"token": token}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
with flask_app_with_containers.test_request_context("/", json={"token": token}):
|
||||
with pytest.raises(MemberNotInTenantError):
|
||||
method(api, member.id)
|
||||
method(api, current_user, member.id)
|
||||
|
||||
def test_member_not_found(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = OwnerTransfer()
|
||||
@ -258,12 +242,9 @@ class TestOwnerTransferApiWithContainers:
|
||||
tenant, current_user = factory.create_owner_workspace(db_session_with_containers)
|
||||
token = factory.create_owner_transfer_token(current_user)
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"token": token}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
):
|
||||
with flask_app_with_containers.test_request_context("/", json={"token": token}):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, str(uuid4()))
|
||||
method(api, current_user, str(uuid4()))
|
||||
|
||||
def test_transfer_success(self, flask_app_with_containers, db_session_with_containers):
|
||||
api = OwnerTransfer()
|
||||
@ -280,11 +261,10 @@ class TestOwnerTransferApiWithContainers:
|
||||
|
||||
with (
|
||||
flask_app_with_containers.test_request_context("/", json={"token": token}),
|
||||
patch.object(members_module, "current_account_with_tenant", return_value=(current_user, tenant.id)),
|
||||
patch.object(members_module.AccountService, "send_new_owner_transfer_notify_email") as mock_new_owner_email,
|
||||
patch.object(members_module.AccountService, "send_old_owner_transfer_notify_email") as mock_old_owner_email,
|
||||
):
|
||||
result = method(api, member.id)
|
||||
result = method(api, current_user, member.id)
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert (
|
||||
|
||||
@ -61,10 +61,9 @@ class TestAppImportApi:
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
|
||||
session.rollback.assert_called_once_with()
|
||||
session.commit.assert_not_called()
|
||||
@ -83,10 +82,9 @@ class TestAppImportApi:
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
|
||||
session.commit.assert_called_once_with()
|
||||
session.rollback.assert_not_called()
|
||||
@ -107,10 +105,9 @@ class TestAppImportApi:
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
|
||||
session.commit.assert_called_once_with()
|
||||
session.rollback.assert_not_called()
|
||||
@ -135,10 +132,9 @@ class TestAppImportConfirmApi:
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
response, status = method(SimpleNamespace(id="u1"), import_id="import-1")
|
||||
|
||||
session.rollback.assert_called_once_with()
|
||||
session.commit.assert_not_called()
|
||||
|
||||
@ -29,7 +29,6 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch:
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
@ -41,7 +40,7 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch:
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(account, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
@ -51,7 +50,6 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module,
|
||||
"parse_time_range",
|
||||
@ -64,7 +62,7 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes
|
||||
query_string={"start": "bad"},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
method(account, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -72,7 +70,6 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
@ -84,7 +81,7 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
|
||||
response = method(account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
|
||||
|
||||
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
@ -95,10 +92,9 @@ def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = conversation
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
|
||||
result = conversation_module._get_conversation(_make_account(), SimpleNamespace(id="app-1"), "c1")
|
||||
|
||||
assert result is conversation
|
||||
session.execute.assert_called_once()
|
||||
@ -110,18 +106,16 @@ def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPat
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
|
||||
conversation_module._get_conversation(_make_account(), SimpleNamespace(id="app-1"), "missing")
|
||||
|
||||
|
||||
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationDetailApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
@ -129,4 +123,4 @@ def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.Monke
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
|
||||
method(_make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
|
||||
|
||||
@ -38,11 +38,6 @@ def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
|
||||
|
||||
|
||||
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
@ -60,7 +55,7 @@ def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPat
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
|
||||
|
||||
@ -74,7 +69,7 @@ def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.Monk
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
@ -88,7 +83,7 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 1
|
||||
@ -106,7 +101,7 @@ def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyP
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
|
||||
|
||||
@ -128,17 +123,12 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes
|
||||
raise ValueError("Invalid time range")
|
||||
|
||||
_install_db(monkeypatch, [])
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -154,7 +144,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 3
|
||||
@ -168,7 +158,7 @@ def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPat
|
||||
_install_db(monkeypatch, [])
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": []}
|
||||
|
||||
@ -179,11 +169,6 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_db(monkeypatch, rows)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
@ -192,7 +177,7 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
@ -209,7 +194,7 @@ def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.Monk
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 2
|
||||
|
||||
@ -41,6 +41,15 @@ def encode_code(code: str) -> str:
|
||||
return base64.b64encode(code.encode("utf-8")).decode()
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class TestLoginApi:
|
||||
"""Test cases for the LoginApi endpoint."""
|
||||
|
||||
@ -486,13 +495,9 @@ class TestLogoutApi:
|
||||
account.email = "test@example.com"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.login.AccountService.logout")
|
||||
@patch("controllers.console.auth.login.flask_login.logout_user")
|
||||
def test_successful_logout(
|
||||
self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app: Flask, mock_account
|
||||
):
|
||||
def test_successful_logout(self, mock_logout_user, mock_service_logout, app: Flask, mock_account):
|
||||
"""
|
||||
Test successful logout flow.
|
||||
|
||||
@ -502,23 +507,18 @@ class TestLogoutApi:
|
||||
- All authentication cookies are cleared
|
||||
- Success response is returned
|
||||
"""
|
||||
# Arrange
|
||||
mock_current_account.return_value = (mock_account, MagicMock())
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/logout", method="POST"):
|
||||
logout_api = LogoutApi()
|
||||
response = logout_api.post()
|
||||
response = _unwrap(logout_api.post)(mock_account)
|
||||
|
||||
# Assert
|
||||
mock_service_logout.assert_called_once_with(account=mock_account)
|
||||
mock_logout_user.assert_called_once()
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.current_account_with_tenant")
|
||||
@patch("controllers.console.auth.login.flask_login")
|
||||
def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app: Flask):
|
||||
def test_logout_anonymous_user(self, mock_flask_login, app: Flask):
|
||||
"""
|
||||
Test logout for anonymous (not logged in) user.
|
||||
|
||||
@ -532,12 +532,11 @@ class TestLogoutApi:
|
||||
anonymous_user = MagicMock()
|
||||
mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {})
|
||||
anonymous_user.__class__ = mock_flask_login.AnonymousUserMixin
|
||||
mock_current_account.return_value = (anonymous_user, None)
|
||||
|
||||
# Act
|
||||
with app.test_request_context("/logout", method="POST"):
|
||||
logout_api = LogoutApi()
|
||||
response = logout_api.post()
|
||||
response = _unwrap(logout_api.post)(anonymous_user)
|
||||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@ -96,10 +96,6 @@ class TestDatasetMetadataCreateApi:
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
MetadataArgs,
|
||||
"model_validate",
|
||||
@ -120,7 +116,7 @@ class TestDatasetMetadataCreateApi:
|
||||
return_value={"id": "m1", "type": "string", "name": "author"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
result, status = method(api, current_user, dataset_id)
|
||||
|
||||
assert status == 201
|
||||
assert result["type"] == "string"
|
||||
@ -143,10 +139,6 @@ class TestDatasetMetadataCreateApi:
|
||||
new_callable=PropertyMock,
|
||||
return_value=valid_payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
MetadataArgs,
|
||||
"model_validate",
|
||||
@ -159,7 +151,7 @@ class TestDatasetMetadataCreateApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
method(api, current_user, dataset_id)
|
||||
|
||||
|
||||
class TestDatasetMetadataGetApi:
|
||||
@ -220,10 +212,6 @@ class TestDatasetMetadataApi:
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
@ -239,7 +227,7 @@ class TestDatasetMetadataApi:
|
||||
return_value={"id": "m1", "type": "string", "name": "updated-name"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, metadata_id)
|
||||
result, status = method(api, current_user, dataset_id, metadata_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["type"] == "string"
|
||||
@ -251,10 +239,6 @@ class TestDatasetMetadataApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
@ -269,7 +253,7 @@ class TestDatasetMetadataApi:
|
||||
"delete_metadata",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, metadata_id)
|
||||
result, status = method(api, current_user, dataset_id, metadata_id)
|
||||
|
||||
assert status == 204
|
||||
assert result == ""
|
||||
@ -307,10 +291,6 @@ class TestDatasetMetadataBuiltInFieldActionApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
@ -325,7 +305,7 @@ class TestDatasetMetadataBuiltInFieldActionApi:
|
||||
"enable_built_in_field",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, "enable")
|
||||
result, status = method(api, current_user, dataset_id, "enable")
|
||||
|
||||
assert status == 204
|
||||
assert result == ""
|
||||
@ -346,10 +326,6 @@ class TestDocumentMetadataEditApi:
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
@ -369,7 +345,7 @@ class TestDocumentMetadataEditApi:
|
||||
"update_documents_metadata",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
result, status = method(api, current_user, dataset_id)
|
||||
|
||||
assert status == 204
|
||||
assert result == ""
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
@ -39,24 +39,14 @@ class TestMemberInviteEmailApi:
|
||||
|
||||
@patch("controllers.console.workspace.members.FeatureService.get_features")
|
||||
@patch("controllers.console.workspace.members.RegisterService.invite_new_member")
|
||||
@patch("controllers.console.workspace.members.current_account_with_tenant")
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
def test_invite_normalizes_emails(
|
||||
self,
|
||||
mock_csrf,
|
||||
mock_db,
|
||||
mock_current_account,
|
||||
mock_invite_member,
|
||||
mock_get_features,
|
||||
app: Flask,
|
||||
):
|
||||
def test_invite_normalizes_emails(self, mock_csrf, mock_db, mock_invite_member, mock_get_features, app: Flask):
|
||||
mock_get_features.return_value = _build_feature_flags()
|
||||
mock_invite_member.return_value = "token-abc"
|
||||
|
||||
tenant = SimpleNamespace(id="tenant-1", name="Test Tenant")
|
||||
inviter = SimpleNamespace(email="Owner@Example.com", current_tenant=tenant, status="active")
|
||||
mock_current_account.return_value = (inviter, tenant.id)
|
||||
|
||||
with (
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "https://console.example.com"),
|
||||
@ -84,5 +74,5 @@ class TestMemberInviteEmailApi:
|
||||
assert call_args.kwargs["email"] == "user@example.com"
|
||||
assert call_args.kwargs["language"] == "en-US"
|
||||
assert call_args.kwargs["role"] == TenantAccountRole.EDITOR
|
||||
assert call_args.kwargs["inviter"] == inviter
|
||||
mock_csrf.assert_called_once()
|
||||
assert call_args.kwargs["inviter"] == account
|
||||
mock_csrf.assert_called_once_with(ANY, account.id)
|
||||
|
||||
@ -49,10 +49,9 @@ class TestMemberListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, user)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["accounts"]) == 1
|
||||
@ -65,10 +64,9 @@ class TestMemberListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestMemberInviteEmailApi:
|
||||
@ -96,7 +94,6 @@ class TestMemberInviteEmailApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
|
||||
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
|
||||
@ -104,7 +101,7 @@ class TestMemberInviteEmailApi:
|
||||
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, user)
|
||||
|
||||
assert status == 201
|
||||
assert result["result"] == "success"
|
||||
@ -127,14 +124,13 @@ class TestMemberInviteEmailApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
|
||||
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
|
||||
):
|
||||
with pytest.raises(WorkspaceMembersLimitExceeded):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_invite_billing_limit_exceeded(self, app: Flask):
|
||||
api = MemberInviteEmailApi()
|
||||
@ -155,7 +151,6 @@ class TestMemberInviteEmailApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.members._count_new_member_invites", return_value=2),
|
||||
patch("controllers.console.workspace.members._count_current_members", return_value=9),
|
||||
@ -163,7 +158,7 @@ class TestMemberInviteEmailApi:
|
||||
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", True),
|
||||
):
|
||||
with pytest.raises(WorkspaceMembersLimitExceeded):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_invite_already_member(self, app: Flask):
|
||||
api = MemberInviteEmailApi()
|
||||
@ -183,7 +178,6 @@ class TestMemberInviteEmailApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.members._count_new_member_invites", return_value=0),
|
||||
patch(
|
||||
@ -194,7 +188,7 @@ class TestMemberInviteEmailApi:
|
||||
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, user)
|
||||
|
||||
assert result["invitation_results"][0]["status"] == "success"
|
||||
|
||||
@ -208,7 +202,7 @@ class TestMemberInviteEmailApi:
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
result, status = method(api)
|
||||
result, status = method(api, MagicMock())
|
||||
|
||||
assert status == 400
|
||||
assert result["code"] == "invalid-role"
|
||||
@ -231,7 +225,6 @@ class TestMemberInviteEmailApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.members._count_new_member_invites", return_value=1),
|
||||
patch(
|
||||
@ -242,7 +235,7 @@ class TestMemberInviteEmailApi:
|
||||
patch("controllers.console.workspace.members.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.members.dify_config.BILLING_ENABLED", False),
|
||||
):
|
||||
result, _ = method(api)
|
||||
result, _ = method(api, user)
|
||||
|
||||
assert result["invitation_results"][0]["status"] == "failed"
|
||||
|
||||
@ -255,7 +248,7 @@ class TestMemberUpdateRoleApi:
|
||||
payload = {"role": "invalid-role"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
result, status = method(api, "id")
|
||||
result, status = method(api, MagicMock(), "id")
|
||||
|
||||
assert status == 400
|
||||
|
||||
@ -278,12 +271,11 @@ class TestDatasetOperatorMemberListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, user)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["accounts"]) == 1
|
||||
@ -296,10 +288,9 @@ class TestDatasetOperatorMemberListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestSendOwnerTransferEmailApi:
|
||||
@ -316,13 +307,12 @@ class TestSendOwnerTransferEmailApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, user)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -338,7 +328,7 @@ class TestSendOwnerTransferEmailApi:
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
|
||||
):
|
||||
with pytest.raises(EmailSendIpLimitError):
|
||||
method(api)
|
||||
method(api, MagicMock())
|
||||
|
||||
def test_send_not_owner(self, app: Flask):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
@ -351,11 +341,10 @@ class TestSendOwnerTransferEmailApi:
|
||||
app.test_request_context("/", json={}),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
|
||||
):
|
||||
with pytest.raises(NotOwnerError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestOwnerTransferCheckApi:
|
||||
@ -370,7 +359,6 @@ class TestOwnerTransferCheckApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
@ -382,7 +370,7 @@ class TestOwnerTransferCheckApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(EmailCodeError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_rate_limited(self, app: Flask):
|
||||
api = OwnerTransferCheckApi()
|
||||
@ -395,7 +383,6 @@ class TestOwnerTransferCheckApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
@ -403,7 +390,7 @@ class TestOwnerTransferCheckApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(OwnerTransferLimitError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_invalid_token(self, app: Flask):
|
||||
api = OwnerTransferCheckApi()
|
||||
@ -416,7 +403,6 @@ class TestOwnerTransferCheckApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
@ -425,7 +411,7 @@ class TestOwnerTransferCheckApi:
|
||||
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_invalid_email(self, app: Flask):
|
||||
api = OwnerTransferCheckApi()
|
||||
@ -438,7 +424,6 @@ class TestOwnerTransferCheckApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
@ -450,7 +435,7 @@ class TestOwnerTransferCheckApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidEmailError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestOwnerTransferApi:
|
||||
@ -465,11 +450,10 @@ class TestOwnerTransferApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
):
|
||||
with pytest.raises(CannotTransferOwnerToSelfError):
|
||||
method(api, "1")
|
||||
method(api, user, "1")
|
||||
|
||||
def test_invalid_token(self, app: Flask):
|
||||
api = OwnerTransfer()
|
||||
@ -482,9 +466,8 @@ class TestOwnerTransferApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api, "2")
|
||||
method(api, user, "2")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user