mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 21:11:16 +08:00
chore: make AccountService.load_user use passed session (#37764)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f665bcac95
commit
2112115962
@ -133,8 +133,9 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
password=new_password,
|
||||
language=language,
|
||||
create_workspace_required=False,
|
||||
session=db.session,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name, session=db.session)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
|
||||
@ -19,6 +19,7 @@ from controllers.console.wraps import (
|
||||
rbac_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
Annotation,
|
||||
@ -388,7 +389,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
update_args["answer"] = args.answer
|
||||
if args.question is not None:
|
||||
update_args["question"] = args.question
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id))
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
update_args, str(app_id), str(annotation_id), db.session
|
||||
)
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@ -398,7 +401,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT)
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
def delete(self, app_id: UUID, annotation_id: UUID):
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id), db.session)
|
||||
return "", 204
|
||||
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ from controllers.console.wraps import (
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import AccountWithRole
|
||||
from libs.helper import build_avatar_url, dump_response, to_timestamp
|
||||
@ -489,7 +490,7 @@ class WorkflowCommentMentionUsersApi(Resource):
|
||||
current_tenant = current_user.current_tenant # need the tenant object here
|
||||
if current_tenant is None:
|
||||
raise ValueError("current tenant is required")
|
||||
members = TenantService.get_tenant_members(current_tenant)
|
||||
members = TenantService.get_tenant_members(current_tenant, session=db.session)
|
||||
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersPayload(users=users)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@ -89,7 +89,9 @@ class ActivateCheckApi(Resource):
|
||||
workspaceId = args.workspace_id
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(
|
||||
workspaceId, args.email, token, session=db.session
|
||||
)
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
tenant = invitation.get("tenant", None)
|
||||
@ -137,7 +139,9 @@ class ActivateApi(Resource):
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
normalized_request_email = args.email.lower() if args.email else None
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(
|
||||
args.workspace_id, args.email, args.token, session=db.session
|
||||
)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
@ -184,6 +188,6 @@ class ActivateApi(Resource):
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
TenantService.switch_tenant(account, tenant.id, session=db.session)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -187,7 +187,7 @@ class EmailRegisterResetApi(Resource):
|
||||
timezone=args.timezone,
|
||||
language=args.language,
|
||||
)
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
@ -206,6 +206,7 @@ class EmailRegisterResetApi(Resource):
|
||||
password=password,
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=timezone,
|
||||
session=db.session,
|
||||
)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
@ -198,10 +198,10 @@ class ForgotPasswordResetApi(Resource):
|
||||
|
||||
# Create workspace if needed
|
||||
if (
|
||||
not TenantService.get_join_tenants(account)
|
||||
not TenantService.get_join_tenants(account, session=db.session)
|
||||
and FeatureService.get_system_features().is_allow_create_workspace
|
||||
):
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
@ -119,7 +119,9 @@ class LoginApi(Resource):
|
||||
invite_token = args.invite_token
|
||||
invitation_data: InvitationDetailDict | None = None
|
||||
if invite_token:
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(
|
||||
None, request_email, invite_token, session=db.session
|
||||
)
|
||||
if invitation_data is None:
|
||||
invite_token = None
|
||||
|
||||
@ -145,7 +147,7 @@ class LoginApi(Resource):
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
tenants = TenantService.get_join_tenants(account, session=db.session)
|
||||
if len(tenants) == 0:
|
||||
system_features = FeatureService.get_system_features()
|
||||
|
||||
@ -157,7 +159,7 @@ class LoginApi(Resource):
|
||||
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
@ -291,7 +293,7 @@ class EmailCodeLoginApi(Resource):
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
tenants = TenantService.get_join_tenants(account, session=db.session)
|
||||
if not tenants:
|
||||
workspaces = FeatureService.get_system_features().license.workspaces
|
||||
if not workspaces.is_available():
|
||||
@ -299,7 +301,7 @@ class EmailCodeLoginApi(Resource):
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
else:
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
|
||||
TenantService.create_tenant_member(new_tenant, account, db.session, role="owner")
|
||||
account.current_tenant = new_tenant
|
||||
tenant_was_created.send(new_tenant)
|
||||
@ -311,6 +313,7 @@ class EmailCodeLoginApi(Resource):
|
||||
name=user_email,
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=args.timezone,
|
||||
session=db.session,
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
@ -319,7 +322,7 @@ class EmailCodeLoginApi(Resource):
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
token_pair = AccountService.login(account, session=db.session, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
@ -343,7 +346,7 @@ class RefreshTokenApi(Resource):
|
||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(refresh_token)
|
||||
new_token_pair = AccountService.refresh_token(refresh_token, session=db.session)
|
||||
|
||||
# Create response with new cookies
|
||||
response = make_response({"result": "success"})
|
||||
@ -358,22 +361,22 @@ class RefreshTokenApi(Resource):
|
||||
|
||||
|
||||
def _get_account_with_case_fallback(email: str):
|
||||
account = AccountService.get_user_through_email(email)
|
||||
account = AccountService.get_user_through_email(email, session=db.session)
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return AccountService.get_user_through_email(email.lower())
|
||||
return AccountService.get_user_through_email(email.lower(), session=db.session)
|
||||
|
||||
|
||||
def _authenticate_account_with_case_fallback(
|
||||
original_email: str, normalized_email: str, password: str, invite_token: str | None
|
||||
):
|
||||
try:
|
||||
return AccountService.authenticate(original_email, password, invite_token)
|
||||
return AccountService.authenticate(original_email, password, invite_token, session=db.session)
|
||||
except services.errors.account.AccountPasswordError:
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
return AccountService.authenticate(normalized_email, password, invite_token, session=db.session)
|
||||
|
||||
|
||||
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
|
||||
@ -195,7 +195,7 @@ class OAuthCallback(Resource):
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, session=db.session)
|
||||
except Unauthorized:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
@ -206,6 +206,7 @@ class OAuthCallback(Resource):
|
||||
|
||||
token_pair = AccountService.login(
|
||||
account=account,
|
||||
session=db.session,
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
@ -240,12 +241,12 @@ def _generate_account(
|
||||
oauth_new_user = False
|
||||
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
tenants = TenantService.get_join_tenants(account, session=db.session)
|
||||
if not tenants:
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise WorkSpaceNotAllowedCreateError()
|
||||
else:
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
|
||||
TenantService.create_tenant_member(new_tenant, account, db.session, role="owner")
|
||||
account.current_tenant = new_tenant
|
||||
tenant_was_created.send(new_tenant)
|
||||
@ -272,9 +273,10 @@ def _generate_account(
|
||||
provider=provider,
|
||||
language=interface_language,
|
||||
timezone=timezone,
|
||||
session=db.session,
|
||||
)
|
||||
|
||||
# Link account
|
||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
AccountService.link_account_integrate(provider, user_info.id, account, session=db.session)
|
||||
|
||||
return account, oauth_new_user
|
||||
|
||||
@ -175,7 +175,7 @@ class InstalledAppsListApi(Resource):
|
||||
|
||||
if current_user.current_tenant is None:
|
||||
raise ValueError("current_user.current_tenant must not be None")
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant, session=db.session)
|
||||
installed_app_list: list[dict[str, Any]] = []
|
||||
for installed_app, app_model in installed_apps:
|
||||
installed_app_list.append(
|
||||
|
||||
@ -50,7 +50,7 @@ def get_init_status() -> InitStatusResponse:
|
||||
@only_edition_self_hosted
|
||||
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
|
||||
"""Validate initialization password."""
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
tenant_count = TenantService.get_tenant_count(session=db.session)
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
|
||||
@ -79,7 +79,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
if get_setup_status():
|
||||
raise AlreadySetupError()
|
||||
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
tenant_count = TenantService.get_tenant_count(session=db.session)
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
@ -94,6 +94,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
password=payload.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
language=payload.language,
|
||||
session=db.session,
|
||||
)
|
||||
|
||||
return SetupResponse(result="success")
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
@ -43,7 +44,7 @@ def socket_connect(sid, environ, auth):
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
user = AccountService.load_logged_in_account(account_id=user_id, session=db.session)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
@ -328,7 +328,7 @@ class AccountNameApi(Resource):
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
updated_account = AccountService.update_account(current_user, session=db.session, name=args.name)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -374,7 +374,7 @@ class AccountAvatarApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountAvatarPayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
|
||||
updated_account = AccountService.update_account(current_user, session=db.session, avatar=args.avatar)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -391,7 +391,9 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceLanguagePayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
|
||||
updated_account = AccountService.update_account(
|
||||
current_user, session=db.session, interface_language=args.interface_language
|
||||
)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -408,7 +410,9 @@ class AccountInterfaceThemeApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceThemePayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
|
||||
updated_account = AccountService.update_account(
|
||||
current_user, session=db.session, interface_theme=args.interface_theme
|
||||
)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -425,7 +429,7 @@ class AccountTimezoneApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountTimezonePayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
|
||||
updated_account = AccountService.update_account(current_user, session=db.session, timezone=args.timezone)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -443,7 +447,8 @@ class AccountPasswordApi(Resource):
|
||||
args = AccountPasswordPayload.model_validate(payload)
|
||||
|
||||
try:
|
||||
AccountService.update_account_password(current_user, args.password, args.new_password)
|
||||
assert args.password is not None
|
||||
AccountService.update_account_password(current_user, args.password, args.new_password, session=db.session)
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
@ -731,7 +736,7 @@ class ChangeEmailResetApi(Resource):
|
||||
if AccountService.is_account_in_freeze(normalized_new_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
if not AccountService.check_email_unique(normalized_new_email, session=db.session):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
@ -755,7 +760,9 @@ class ChangeEmailResetApi(Resource):
|
||||
# legitimately verified token.
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
updated_account = AccountService.update_account_email(
|
||||
current_user, email=normalized_new_email, session=db.session
|
||||
)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
email=normalized_new_email,
|
||||
@ -775,6 +782,6 @@ class CheckEmailUnique(Resource):
|
||||
normalized_email = args.email.lower()
|
||||
if AccountService.is_account_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(normalized_email):
|
||||
if not AccountService.check_email_unique(normalized_email, session=db.session):
|
||||
raise EmailAlreadyInUseError()
|
||||
return {"result": "success"}
|
||||
|
||||
@ -186,7 +186,7 @@ class MemberListApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant, session=db.session)
|
||||
if dify_config.RBAC_ENABLED:
|
||||
member_ids = [member.id for member in members]
|
||||
member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get(
|
||||
@ -273,6 +273,7 @@ class MemberInviteEmailApi(Resource):
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
session=db.session,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(invitee_email)
|
||||
invitation_results.append(
|
||||
@ -317,7 +318,9 @@ class MemberCancelInviteApi(Resource):
|
||||
abort(404)
|
||||
else:
|
||||
try:
|
||||
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
|
||||
TenantService.remove_member_from_tenant(
|
||||
current_user.current_tenant, member, current_user, session=db.session
|
||||
)
|
||||
except services.errors.account.CannotOperateSelfError as e:
|
||||
return {"code": "cannot-operate-self", "message": str(e)}, 400
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
@ -360,7 +363,9 @@ class MemberUpdateRoleApi(Resource):
|
||||
|
||||
try:
|
||||
assert member is not None, "Member not found"
|
||||
TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user)
|
||||
TenantService.update_member_role(
|
||||
current_user.current_tenant, member, new_role, current_user, session=db.session
|
||||
)
|
||||
except services.errors.account.CannotOperateSelfError as e:
|
||||
return {"code": "cannot-operate-self", "message": str(e)}, 400
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
@ -387,7 +392,7 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
def get(self, current_user: Account):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant, session=db.session)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@ -413,7 +418,7 @@ class SendOwnerTransferEmailApi(Resource):
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
|
||||
raise NotOwnerError()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
@ -448,7 +453,7 @@ class OwnerTransferCheckApi(Resource):
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
|
||||
raise NotOwnerError()
|
||||
|
||||
user_email = current_user.email
|
||||
@ -494,7 +499,7 @@ class OwnerTransfer(Resource):
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
|
||||
raise NotOwnerError()
|
||||
|
||||
if current_user.id == str(member_id):
|
||||
@ -516,12 +521,14 @@ class OwnerTransfer(Resource):
|
||||
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_member(member, current_user.current_tenant):
|
||||
if not TenantService.is_member(member, current_user.current_tenant, session=db.session):
|
||||
raise MemberNotInTenantError()
|
||||
|
||||
try:
|
||||
assert member is not None, "Member not found"
|
||||
TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user)
|
||||
TenantService.update_member_role(
|
||||
current_user.current_tenant, member, "owner", current_user, session=db.session
|
||||
)
|
||||
|
||||
AccountService.send_new_owner_transfer_notify_email(
|
||||
account=member,
|
||||
|
||||
@ -325,10 +325,10 @@ class TenantApi(Resource):
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenants = TenantService.get_join_tenants(current_user, session=db.session)
|
||||
# if there is any tenant, switch to the first one
|
||||
if len(tenants) > 0:
|
||||
TenantService.switch_tenant(current_user, tenants[0].id)
|
||||
TenantService.switch_tenant(current_user, tenants[0].id, session=db.session)
|
||||
tenant = tenants[0]
|
||||
# else, raise Unauthorized
|
||||
else:
|
||||
@ -351,7 +351,7 @@ class SwitchWorkspaceApi(Resource):
|
||||
|
||||
# check if tenant_id is valid, 403 if not
|
||||
try:
|
||||
TenantService.switch_tenant(current_user, args.tenant_id)
|
||||
TenantService.switch_tenant(current_user, args.tenant_id, session=db.session)
|
||||
except Exception:
|
||||
raise AccountNotLinkTenantError("Account not link tenant")
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ class EnterpriseWorkspace(Resource):
|
||||
if account is None:
|
||||
return {"message": "owner account not found."}, 404
|
||||
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
@ -84,7 +84,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
def post(self):
|
||||
args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})
|
||||
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session)
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
|
||||
@ -128,7 +128,7 @@ class WorkspaceSwitchApi(Resource):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
try:
|
||||
TenantService.switch_tenant(account, workspace_id)
|
||||
TenantService.switch_tenant(account, workspace_id, session=db.session)
|
||||
except AccountNotLinkTenantError:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
@ -152,7 +152,7 @@ class WorkspaceMembersApi(Resource):
|
||||
@accepts(query=MemberListQuery)
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery):
|
||||
tenant = _load_tenant(workspace_id)
|
||||
members = TenantService.get_tenant_members(tenant)
|
||||
members = TenantService.get_tenant_members(tenant, session=db.session)
|
||||
total = len(members)
|
||||
start = (query.page - 1) * query.limit
|
||||
page_items = members[start : start + query.limit]
|
||||
@ -184,6 +184,7 @@ class WorkspaceMembersApi(Resource):
|
||||
language=None,
|
||||
role=body.role,
|
||||
inviter=inviter,
|
||||
session=db.session,
|
||||
)
|
||||
except AccountAlreadyInTenantError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
@ -232,7 +233,7 @@ class WorkspaceMemberApi(Resource):
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.remove_member_from_tenant(tenant, member, operator)
|
||||
TenantService.remove_member_from_tenant(tenant, member, operator, session=db.session)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
@ -266,7 +267,7 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.update_member_role(tenant, member, body.role, operator)
|
||||
TenantService.update_member_role(tenant, member, body.role, operator, session=db.session)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.common.schema import query_params_from_model, register_response
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import Annotation, AnnotationList
|
||||
from fields.base import ResponseModel
|
||||
@ -281,7 +282,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
"""Update an existing annotation."""
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
update_args, app_model.id, str(annotation_id), db.session
|
||||
)
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -310,5 +313,5 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, annotation_id: UUID):
|
||||
"""Delete an annotation."""
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id), db.session)
|
||||
return "", 204
|
||||
|
||||
@ -226,7 +226,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
all_multimodal_documents.append(file_document)
|
||||
doc.attachments = attachments
|
||||
else:
|
||||
account = AccountService.load_user(document.created_by)
|
||||
account = AccountService.load_user(document.created_by, db.session)
|
||||
if not account:
|
||||
raise ValueError("Invalid account")
|
||||
doc.attachments = self._get_content_files(doc, current_user=account)
|
||||
|
||||
@ -291,7 +291,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
attachments.append(file_document)
|
||||
doc.attachments = attachments
|
||||
else:
|
||||
account = AccountService.load_user(document.created_by)
|
||||
account = AccountService.load_user(document.created_by, db.session)
|
||||
if not account:
|
||||
raise ValueError("Invalid account")
|
||||
doc.attachments = self._get_content_files(doc, current_user=account)
|
||||
|
||||
@ -84,7 +84,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
||||
if not user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, session=db.session)
|
||||
return logged_in_account
|
||||
elif request.blueprint == "openapi":
|
||||
# Account-branch device-flow approval routes (approve / deny /
|
||||
@ -103,7 +103,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
||||
source = decoded.get("token_source")
|
||||
if source or not user_id:
|
||||
return None
|
||||
return AccountService.load_logged_in_account(account_id=user_id)
|
||||
return AccountService.load_logged_in_account(account_id=user_id, session=db.session)
|
||||
elif request.blueprint == "web":
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||
|
||||
@ -16,7 +16,7 @@ def valid_password(password):
|
||||
raise ValueError("Password must contain letters and numbers, and the length must be at least 8 characters.")
|
||||
|
||||
|
||||
def hash_password(password_str, salt_byte):
|
||||
def hash_password(password_str: str, salt_byte: bytes):
|
||||
dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
|
||||
return binascii.hexlify(dk)
|
||||
|
||||
|
||||
@ -1,3 +1,10 @@
|
||||
"""Account, workspace, and invitation services.
|
||||
|
||||
Database access in this module is caller-scoped: methods that read or mutate ORM state accept an explicit
|
||||
``session`` so controllers, tasks, and tests can control transaction lifetime and avoid hidden Flask-scoped session
|
||||
usage inside service logic.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@ -235,7 +242,7 @@ class AccountService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _refresh_account_last_active(account: Account) -> None:
|
||||
def _refresh_account_last_active(account: Account, session: scoped_session | Session) -> None:
|
||||
now = naive_utc_now()
|
||||
refresh_before = now - ACCOUNT_LAST_ACTIVE_REFRESH_INTERVAL
|
||||
|
||||
@ -245,12 +252,12 @@ class AccountService:
|
||||
if not AccountService._should_refresh_account_last_active(account.id):
|
||||
return
|
||||
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
update(Account)
|
||||
.where(Account.id == account.id, Account.last_active_at < refresh_before)
|
||||
.values(last_active_at=now, updated_at=func.current_timestamp())
|
||||
)
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def _store_refresh_token(refresh_token: str, account_id: str):
|
||||
@ -295,20 +302,20 @@ class AccountService:
|
||||
side-effects (current-tenant assignment, commit) are unwanted.
|
||||
|
||||
``session`` is injected by the caller so this service stays free
|
||||
of the Flask-scoped ``db.session`` import.
|
||||
of a Flask-scoped session import.
|
||||
"""
|
||||
return session.get(Account, account_id)
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> None | Account:
|
||||
account = db.session.get(Account, user_id)
|
||||
def load_user(user_id: str, session: scoped_session | Session) -> None | Account:
|
||||
account = session.get(Account, user_id)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Unauthorized("Account is banned.")
|
||||
|
||||
current_tenant = db.session.scalar(
|
||||
current_tenant = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
|
||||
.limit(1)
|
||||
@ -316,7 +323,7 @@ class AccountService:
|
||||
if current_tenant:
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
else:
|
||||
available_ta = db.session.scalar(
|
||||
available_ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
@ -328,13 +335,13 @@ class AccountService:
|
||||
account.set_tenant_id(available_ta.tenant_id)
|
||||
available_ta.current = True
|
||||
available_ta.last_opened_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
AccountService._refresh_account_last_active(account)
|
||||
AccountService._refresh_account_last_active(account, session)
|
||||
# NOTE: make sure account is accessible outside of a db session
|
||||
# This ensures that it will work correctly after upgrading to Flask version 3.1.2
|
||||
db.session.refresh(account)
|
||||
db.session.close()
|
||||
session.refresh(account)
|
||||
session.close()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
@ -352,10 +359,12 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
|
||||
def authenticate(
|
||||
email: str, password: str, invite_token: str | None = None, *, session: scoped_session | Session
|
||||
) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
account = session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if not account:
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
@ -378,12 +387,14 @@ class AccountService:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_account_password(account, password, new_password):
|
||||
def update_account_password(
|
||||
account: Account, password: str, new_password: str, *, session: scoped_session | Session
|
||||
):
|
||||
"""update account password"""
|
||||
if account.password and not compare_password(password, account.password, account.password_salt):
|
||||
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
||||
@ -400,8 +411,8 @@ class AccountService:
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
session.add(account)
|
||||
session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
@ -413,6 +424,8 @@ class AccountService:
|
||||
interface_theme: str = "light",
|
||||
is_setup: bool | None = False,
|
||||
timezone: str | None = None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> Account:
|
||||
"""Create an account, preferring explicit user timezone over language-derived defaults."""
|
||||
if not FeatureService.get_system_features().is_allow_register and not is_setup:
|
||||
@ -458,13 +471,19 @@ class AccountService:
|
||||
timezone=resolved_timezone,
|
||||
)
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
session.add(account)
|
||||
session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_account_and_tenant(
|
||||
email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None
|
||||
email: str,
|
||||
name: str,
|
||||
interface_language: str,
|
||||
password: str | None = None,
|
||||
timezone: str | None = None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> Account:
|
||||
"""Create an account and owner workspace."""
|
||||
account = AccountService.create_account(
|
||||
@ -473,10 +492,11 @@ class AccountService:
|
||||
interface_language=interface_language,
|
||||
password=password,
|
||||
timezone=timezone,
|
||||
session=session,
|
||||
)
|
||||
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, session=session)
|
||||
except Exception:
|
||||
# Enterprise-only side-effect should run independently from personal workspace creation.
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
@ -536,11 +556,11 @@ class AccountService:
|
||||
delete_account_task.delay(account.id)
|
||||
|
||||
@staticmethod
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account):
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account, *, session: scoped_session | Session):
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
account_integrate: AccountIntegrate | None = db.session.scalar(
|
||||
account_integrate: AccountIntegrate | None = session.scalar(
|
||||
select(AccountIntegrate)
|
||||
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
|
||||
.limit(1)
|
||||
@ -556,62 +576,62 @@ class AccountService:
|
||||
account_integrate = AccountIntegrate(
|
||||
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
|
||||
)
|
||||
db.session.add(account_integrate)
|
||||
session.add(account_integrate)
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
logger.info("Account %s linked %s account %s.", account.id, provider, open_id)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id)
|
||||
raise LinkAccountIntegrateError("Failed to link account.") from e
|
||||
|
||||
@staticmethod
|
||||
def close_account(account: Account):
|
||||
def close_account(account: Account, *, session: scoped_session | Session):
|
||||
"""Close account"""
|
||||
account.status = AccountStatus.CLOSED
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_account(account, **kwargs):
|
||||
def update_account(account: Account, *, session: scoped_session | Session, **kwargs):
|
||||
"""Update account fields"""
|
||||
account = db.session.merge(account)
|
||||
account = session.merge(account)
|
||||
for field, value in kwargs.items():
|
||||
if hasattr(account, field):
|
||||
setattr(account, field, value)
|
||||
else:
|
||||
raise AttributeError(f"Invalid field: {field}")
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_account_email(account: Account, email: str) -> Account:
|
||||
def update_account_email(account: Account, email: str, session: scoped_session | Session) -> Account:
|
||||
"""Update account email"""
|
||||
account.email = email
|
||||
account_integrate = db.session.scalar(
|
||||
account_integrate = session.scalar(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
|
||||
)
|
||||
if account_integrate:
|
||||
db.session.delete(account_integrate)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
session.delete(account_integrate)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_login_info(account: Account, *, ip_address: str):
|
||||
def update_login_info(account: Account, session: scoped_session | Session, *, ip_address: str):
|
||||
"""Update last login time and ip"""
|
||||
account.last_login_at = naive_utc_now()
|
||||
account.last_login_ip = ip_address
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
|
||||
def login(account: Account, *, session: scoped_session | Session, ip_address: str | None = None) -> TokenPair:
|
||||
if ip_address:
|
||||
AccountService.update_login_info(account=account, ip_address=ip_address)
|
||||
AccountService.update_login_info(account=account, session=session, ip_address=ip_address)
|
||||
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
access_token = AccountService.get_account_jwt_token(account=account)
|
||||
refresh_token = _generate_refresh_token()
|
||||
@ -628,13 +648,13 @@ class AccountService:
|
||||
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
||||
|
||||
@staticmethod
|
||||
def refresh_token(refresh_token: str) -> TokenPair:
|
||||
def refresh_token(refresh_token: str, *, session: scoped_session | Session) -> TokenPair:
|
||||
# Verify the refresh token
|
||||
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
|
||||
if not account_id:
|
||||
raise ValueError("Invalid refresh token")
|
||||
|
||||
account = AccountService.load_user(account_id.decode("utf-8"))
|
||||
account = AccountService.load_user(account_id.decode("utf-8"), session)
|
||||
if not account:
|
||||
raise ValueError("Invalid account")
|
||||
|
||||
@ -649,8 +669,8 @@ class AccountService:
|
||||
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token)
|
||||
|
||||
@staticmethod
|
||||
def load_logged_in_account(*, account_id: str):
|
||||
return AccountService.load_user(account_id)
|
||||
def load_logged_in_account(*, account_id: str, session: scoped_session | Session):
|
||||
return AccountService.load_user(account_id, session)
|
||||
|
||||
@classmethod
|
||||
def send_reset_password_email(
|
||||
@ -1002,7 +1022,7 @@ class AccountService:
|
||||
TokenManager.revoke_token(token, "email_code_login")
|
||||
|
||||
@classmethod
|
||||
def get_user_through_email(cls, email: str):
|
||||
def get_user_through_email(cls, email: str, *, session: scoped_session | Session):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
|
||||
raise AccountRegisterError(
|
||||
description=(
|
||||
@ -1011,7 +1031,7 @@ class AccountService:
|
||||
)
|
||||
)
|
||||
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
account = session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if not account:
|
||||
return None
|
||||
|
||||
@ -1210,13 +1230,19 @@ class AccountService:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def check_email_unique(email: str) -> bool:
|
||||
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
|
||||
def check_email_unique(email: str, *, session: scoped_session | Session) -> bool:
|
||||
return session.scalar(select(Account).where(Account.email == email).limit(1)) is None
|
||||
|
||||
|
||||
class TenantService:
|
||||
@staticmethod
|
||||
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
|
||||
def create_tenant(
|
||||
name: str,
|
||||
is_setup: bool | None = False,
|
||||
is_from_dashboard: bool | None = False,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> Tenant:
|
||||
"""Create tenant"""
|
||||
if (
|
||||
not FeatureService.get_system_features().is_allow_create_workspace
|
||||
@ -1228,8 +1254,8 @@ class TenantService:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
tenant = Tenant(name=name)
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
session.add(tenant)
|
||||
session.commit()
|
||||
|
||||
for category in TenantPluginAutoUpgradeStrategy.PluginCategory:
|
||||
plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy(
|
||||
@ -1241,11 +1267,11 @@ class TenantService:
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
db.session.add(plugin_upgrade_strategy)
|
||||
db.session.commit()
|
||||
session.add(plugin_upgrade_strategy)
|
||||
session.commit()
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
@ -1254,9 +1280,11 @@ class TenantService:
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
|
||||
def create_owner_tenant_if_not_exist(
|
||||
account: Account, name: str | None = None, is_setup: bool | None = False, *, session: scoped_session | Session
|
||||
):
|
||||
"""Check if user have a workspace or not"""
|
||||
available_ta = db.session.scalar(
|
||||
available_ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
@ -1275,10 +1303,10 @@ class TenantService:
|
||||
raise WorkspacesLimitExceededError()
|
||||
|
||||
if name:
|
||||
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
|
||||
tenant = TenantService.create_tenant(name=name, is_setup=is_setup, session=session)
|
||||
else:
|
||||
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
|
||||
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup, session=session)
|
||||
TenantService.create_tenant_member(tenant, account, session, role="owner")
|
||||
if dify_config.RBAC_ENABLED:
|
||||
owner_role_id = AccountService._resolve_legacy_role_id(str(tenant.id), account.id, TenantAccountRole.OWNER)
|
||||
RBACService.MemberRoles.replace(
|
||||
@ -1288,16 +1316,16 @@ class TenantService:
|
||||
role_ids=[owner_role_id],
|
||||
)
|
||||
account.current_tenant = tenant
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
@staticmethod
|
||||
def create_tenant_member(
|
||||
tenant: Tenant, account: Account, session: scoped_session, role: str = "normal"
|
||||
tenant: Tenant, account: Account, session: scoped_session | Session, role: str = "normal"
|
||||
) -> TenantAccountJoin:
|
||||
"""Create tenant member"""
|
||||
if role == TenantAccountRole.OWNER:
|
||||
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
|
||||
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER], session=session):
|
||||
logger.error("Tenant %s has already an owner.", tenant.id)
|
||||
raise Exception("Tenant already has an owner.")
|
||||
|
||||
@ -1318,10 +1346,10 @@ class TenantService:
|
||||
return ta
|
||||
|
||||
@staticmethod
|
||||
def get_join_tenants(account: Account) -> list[Tenant]:
|
||||
def get_join_tenants(account: Account, *, session: scoped_session | Session) -> list[Tenant]:
|
||||
"""Get account join tenants"""
|
||||
return list(
|
||||
db.session.scalars(
|
||||
session.scalars(
|
||||
select(Tenant)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
|
||||
@ -1340,7 +1368,7 @@ class TenantService:
|
||||
membership + pick the default workspace.
|
||||
|
||||
``session`` is injected by the caller so this service stays free
|
||||
of the Flask-scoped ``db.session`` import.
|
||||
of a Flask-scoped session import.
|
||||
|
||||
No tenant-status filter: parity with the legacy controller query
|
||||
(the openapi identity endpoint listed all joined tenants).
|
||||
@ -1413,7 +1441,7 @@ class TenantService:
|
||||
bearers (no account) collapse to the non-member path. Mirrors the
|
||||
session-injection style of :meth:`account_belongs_to_tenant` rather
|
||||
than :meth:`get_user_role`, which loads full ``Account``/``Tenant``
|
||||
objects against the Flask-scoped ``db.session``.
|
||||
objects against the Flask-scoped session.
|
||||
"""
|
||||
if not account_id:
|
||||
return None
|
||||
@ -1479,13 +1507,13 @@ class TenantService:
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def get_current_tenant_by_account(account: Account):
|
||||
def get_current_tenant_by_account(account: Account, *, session: scoped_session | Session):
|
||||
"""Get tenant by account and add the role"""
|
||||
tenant = account.current_tenant
|
||||
if not tenant:
|
||||
raise TenantNotFoundError("Tenant not found.")
|
||||
|
||||
ta = db.session.scalar(
|
||||
ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
@ -1497,14 +1525,14 @@ class TenantService:
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def switch_tenant(account: Account, tenant_id: str | None = None):
|
||||
def switch_tenant(account: Account, tenant_id: str | None = None, *, session: scoped_session | Session):
|
||||
"""Switch the current workspace for the account"""
|
||||
|
||||
# Ensure tenant_id is provided
|
||||
if tenant_id is None:
|
||||
raise ValueError("Tenant ID must be provided.")
|
||||
|
||||
tenant_account_join = db.session.scalar(
|
||||
tenant_account_join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(
|
||||
@ -1518,7 +1546,7 @@ class TenantService:
|
||||
if not tenant_account_join:
|
||||
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
||||
else:
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
update(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
|
||||
.values(current=False)
|
||||
@ -1527,10 +1555,10 @@ class TenantService:
|
||||
tenant_account_join.last_opened_at = naive_utc_now()
|
||||
# Set the current tenant for the account
|
||||
account.set_tenant_id(tenant_account_join.tenant_id)
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_members(tenant: Tenant) -> list[Account]:
|
||||
def get_tenant_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]:
|
||||
"""Get tenant members"""
|
||||
stmt = (
|
||||
select(Account, TenantAccountJoin.role)
|
||||
@ -1542,14 +1570,14 @@ class TenantService:
|
||||
# Initialize an empty list to store the updated accounts
|
||||
updated_accounts = []
|
||||
|
||||
for account, role in db.session.execute(stmt):
|
||||
for account, role in session.execute(stmt):
|
||||
account.role = role
|
||||
updated_accounts.append(account)
|
||||
|
||||
return updated_accounts
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
|
||||
def get_dataset_operator_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]:
|
||||
"""Get dataset admin members"""
|
||||
stmt = (
|
||||
select(Account, TenantAccountJoin.role)
|
||||
@ -1562,20 +1590,20 @@ class TenantService:
|
||||
# Initialize an empty list to store the updated accounts
|
||||
updated_accounts = []
|
||||
|
||||
for account, role in db.session.execute(stmt):
|
||||
for account, role in session.execute(stmt):
|
||||
account.role = role
|
||||
updated_accounts.append(account)
|
||||
|
||||
return updated_accounts
|
||||
|
||||
@staticmethod
|
||||
def has_roles(tenant: Tenant, roles: list[TenantAccountRole]) -> bool:
|
||||
def has_roles(tenant: Tenant, roles: list[TenantAccountRole], *, session: scoped_session | Session) -> bool:
|
||||
"""Check if user has any of the given roles for a tenant"""
|
||||
if not all(isinstance(role, TenantAccountRole) for role in roles):
|
||||
raise ValueError("all roles must be TenantAccountRole")
|
||||
|
||||
return (
|
||||
db.session.scalar(
|
||||
session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
@ -1587,9 +1615,11 @@ class TenantService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
|
||||
def get_user_role(
|
||||
account: Account, tenant: Tenant, *, session: scoped_session | Session
|
||||
) -> TenantAccountRole | None:
|
||||
"""Get the role of the current account for a given tenant"""
|
||||
join = db.session.scalar(
|
||||
join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
@ -1597,12 +1627,14 @@ class TenantService:
|
||||
return TenantAccountRole(join.role) if join else None
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_count() -> int:
|
||||
def get_tenant_count(*, session: scoped_session | Session) -> int:
|
||||
"""Get tenant count"""
|
||||
return cast(int, db.session.scalar(select(func.count(Tenant.id))))
|
||||
return cast(int, session.scalar(select(func.count(Tenant.id))))
|
||||
|
||||
@staticmethod
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
|
||||
def check_member_permission(
|
||||
tenant: Tenant, operator: Account, member: Account | None, action: str, *, session: scoped_session | Session
|
||||
):
|
||||
"""Check member permission"""
|
||||
if action not in {"add", "remove", "update"}:
|
||||
raise InvalidActionError("Invalid action.")
|
||||
@ -1636,7 +1668,7 @@ class TenantService:
|
||||
"update": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
}
|
||||
|
||||
ta_operator = db.session.scalar(
|
||||
ta_operator = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
|
||||
.limit(1)
|
||||
@ -1646,7 +1678,7 @@ class TenantService:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
if action == "remove" and ta_operator.role == TenantAccountRole.ADMIN and member:
|
||||
ta_member = db.session.scalar(
|
||||
ta_member = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
|
||||
.limit(1)
|
||||
@ -1655,7 +1687,9 @@ class TenantService:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
|
||||
def remove_member_from_tenant(
|
||||
tenant: Tenant, account: Account, operator: Account, *, session: scoped_session | Session
|
||||
):
|
||||
"""Remove member from tenant.
|
||||
|
||||
Apps and datasets maintained by the removed member are reassigned to
|
||||
@ -1667,9 +1701,9 @@ class TenantService:
|
||||
if operator.id == account.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
TenantService.check_member_permission(tenant, operator, account, "remove")
|
||||
TenantService.check_member_permission(tenant, operator, account, "remove", session=session)
|
||||
|
||||
ta = db.session.scalar(
|
||||
ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
@ -1686,7 +1720,7 @@ class TenantService:
|
||||
if dify_config.RBAC_ENABLED:
|
||||
owner_id = AccountService.get_rbac_workspace_owner_account_id(str(tenant.id), str(operator.id))
|
||||
else:
|
||||
owner_id = db.session.scalar(
|
||||
owner_id = session.scalar(
|
||||
select(TenantAccountJoin.account_id)
|
||||
.where(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
@ -1697,7 +1731,7 @@ class TenantService:
|
||||
if owner_id is None:
|
||||
raise ValueError(f"Workspace owner not found for tenant {tenant.id}.")
|
||||
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
update(App)
|
||||
.where(
|
||||
App.tenant_id == tenant.id,
|
||||
@ -1705,7 +1739,7 @@ class TenantService:
|
||||
)
|
||||
.values(maintainer=owner_id)
|
||||
)
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
update(Dataset)
|
||||
.where(
|
||||
Dataset.tenant_id == tenant.id,
|
||||
@ -1713,23 +1747,23 @@ class TenantService:
|
||||
)
|
||||
.values(maintainer=owner_id)
|
||||
)
|
||||
db.session.delete(ta)
|
||||
session.delete(ta)
|
||||
|
||||
# Clean up orphaned pending accounts (invited but never activated)
|
||||
should_delete_account = False
|
||||
if account.status == AccountStatus.PENDING:
|
||||
# autoflush flushes ta deletion before this query, so 0 means no remaining joins
|
||||
remaining_joins = (
|
||||
db.session.scalar(
|
||||
session.scalar(
|
||||
select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if remaining_joins == 0:
|
||||
db.session.delete(account)
|
||||
session.delete(account)
|
||||
should_delete_account = True
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
if should_delete_account:
|
||||
logger.info(
|
||||
@ -1755,12 +1789,14 @@ class TenantService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
|
||||
def update_member_role(
|
||||
tenant: Tenant, member: Account, new_role: str, operator: Account, *, session: scoped_session | Session
|
||||
):
|
||||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
TenantService.check_member_permission(tenant, operator, member, "update", session=session)
|
||||
new_tenant_role = TenantAccountRole(new_role)
|
||||
|
||||
target_member_join = db.session.scalar(
|
||||
target_member_join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
|
||||
.limit(1)
|
||||
@ -1769,7 +1805,7 @@ class TenantService:
|
||||
if not target_member_join:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
|
||||
operator_role = TenantService.get_user_role(operator, tenant)
|
||||
operator_role = TenantService.get_user_role(operator, tenant, session=session)
|
||||
target_role = TenantAccountRole(target_member_join.role)
|
||||
if operator_role == TenantAccountRole.ADMIN and (TenantAccountRole.OWNER in {target_role, new_tenant_role}):
|
||||
raise NoPermissionError("No permission to update member.")
|
||||
@ -1779,7 +1815,7 @@ class TenantService:
|
||||
|
||||
if new_role == "owner":
|
||||
# Find the current owner and change their role to 'admin'
|
||||
current_owner_join = db.session.scalar(
|
||||
current_owner_join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
|
||||
.limit(1)
|
||||
@ -1815,7 +1851,7 @@ class TenantService:
|
||||
)
|
||||
else:
|
||||
target_member_join.role = new_tenant_role
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_custom_config(tenant_id: str):
|
||||
@ -1824,13 +1860,13 @@ class TenantService:
|
||||
return tenant.custom_config_dict
|
||||
|
||||
@staticmethod
|
||||
def is_owner(account: Account, tenant: Tenant) -> bool:
|
||||
return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER
|
||||
def is_owner(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool:
|
||||
return TenantService.get_user_role(account, tenant, session=session) == TenantAccountRole.OWNER
|
||||
|
||||
@staticmethod
|
||||
def is_member(account: Account, tenant: Tenant) -> bool:
|
||||
def is_member(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool:
|
||||
"""Check if the account is a member of the tenant"""
|
||||
return TenantService.get_user_role(account, tenant) is not None
|
||||
return TenantService.get_user_role(account, tenant, session=session) is not None
|
||||
|
||||
|
||||
class RegisterService:
|
||||
@ -1839,7 +1875,16 @@ class RegisterService:
|
||||
return f"member_invite:token:{token}"
|
||||
|
||||
@classmethod
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None):
|
||||
def setup(
|
||||
cls,
|
||||
email: str,
|
||||
name: str,
|
||||
password: str,
|
||||
ip_address: str,
|
||||
language: str | None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
):
|
||||
"""
|
||||
Setup dify
|
||||
|
||||
@ -1856,22 +1901,23 @@ class RegisterService:
|
||||
interface_language=get_valid_language(language),
|
||||
password=password,
|
||||
is_setup=True,
|
||||
session=session,
|
||||
)
|
||||
|
||||
account.last_login_ip = ip_address
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True, session=session)
|
||||
|
||||
dify_setup = DifySetup(version=dify_config.project.version)
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
session.add(dify_setup)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
db.session.execute(delete(DifySetup))
|
||||
db.session.execute(delete(TenantAccountJoin))
|
||||
db.session.execute(delete(Account))
|
||||
db.session.execute(delete(Tenant))
|
||||
db.session.commit()
|
||||
session.execute(delete(DifySetup))
|
||||
session.execute(delete(TenantAccountJoin))
|
||||
session.execute(delete(Account))
|
||||
session.execute(delete(Tenant))
|
||||
session.commit()
|
||||
|
||||
logger.exception("Setup account failed, email: %s, name: %s", email, name)
|
||||
raise ValueError(f"Setup failed: {e}")
|
||||
@ -1889,9 +1935,11 @@ class RegisterService:
|
||||
is_setup: bool | None = False,
|
||||
create_workspace_required: bool | None = True,
|
||||
timezone: str | None = None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> Account:
|
||||
"""Register account"""
|
||||
db.session.begin_nested()
|
||||
session.begin_nested()
|
||||
try:
|
||||
interface_language = get_valid_language(language)
|
||||
account = AccountService.create_account(
|
||||
@ -1901,12 +1949,13 @@ class RegisterService:
|
||||
password=password,
|
||||
is_setup=is_setup,
|
||||
timezone=timezone,
|
||||
session=session,
|
||||
)
|
||||
account.status = status or AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
if open_id is not None and provider is not None:
|
||||
AccountService.link_account_integrate(provider, open_id, account)
|
||||
AccountService.link_account_integrate(provider, open_id, account, session=session)
|
||||
|
||||
if (
|
||||
FeatureService.get_system_features().is_allow_create_workspace
|
||||
@ -1914,27 +1963,27 @@ class RegisterService:
|
||||
and FeatureService.get_system_features().license.workspaces.is_available()
|
||||
):
|
||||
try:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=session)
|
||||
TenantService.create_tenant_member(tenant, account, session, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
except Exception:
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
raise
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
db.session.rollback()
|
||||
session.rollback()
|
||||
logger.exception("Register failed")
|
||||
raise AccountRegisterError("Workspace is not allowed to create.")
|
||||
except AccountRegisterError as are:
|
||||
db.session.rollback()
|
||||
session.rollback()
|
||||
logger.exception("Register failed")
|
||||
raise are
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
session.rollback()
|
||||
logger.exception("Register failed")
|
||||
raise AccountRegisterError(f"Registration failed: {e}") from e
|
||||
|
||||
@ -1942,7 +1991,14 @@ class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def invite_new_member(
|
||||
cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None
|
||||
cls,
|
||||
tenant: Tenant,
|
||||
email: str,
|
||||
language: str | None,
|
||||
role: str = "normal",
|
||||
inviter: Account | None = None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> str:
|
||||
if not inviter:
|
||||
raise ValueError("Inviter is required")
|
||||
@ -1960,7 +2016,7 @@ class RegisterService:
|
||||
|
||||
requires_setup = False
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add", session=session)
|
||||
name = normalized_email.split("@")[0]
|
||||
|
||||
account = cls.register(
|
||||
@ -1969,13 +2025,14 @@ class RegisterService:
|
||||
language=language,
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
session=session,
|
||||
)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role)
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
TenantService.create_tenant_member(tenant, account, session, tenant_join_role)
|
||||
TenantService.switch_tenant(account, tenant.id, session=session)
|
||||
requires_setup = True
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = db.session.scalar(
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add", session=session)
|
||||
ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
@ -1983,7 +2040,7 @@ class RegisterService:
|
||||
requires_setup = account.status == AccountStatus.PENDING
|
||||
|
||||
if not ta and (account.status == AccountStatus.PENDING or dify_config.RBAC_ENABLED):
|
||||
TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role)
|
||||
TenantService.create_tenant_member(tenant, account, session, tenant_join_role)
|
||||
|
||||
# Support resend invitation email when the account is pending status
|
||||
if account.status != AccountStatus.PENDING:
|
||||
@ -2052,20 +2109,20 @@ class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def get_invitation_if_token_valid(
|
||||
cls, workspace_id: str | None, email: str | None, token: str
|
||||
cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session
|
||||
) -> InvitationDetailDict | None:
|
||||
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
tenant = db.session.scalar(
|
||||
tenant = session.scalar(
|
||||
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
account = db.session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
|
||||
account = session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
|
||||
if not account:
|
||||
return None
|
||||
|
||||
@ -2105,13 +2162,13 @@ class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def get_invitation_with_case_fallback(
|
||||
cls, workspace_id: str | None, email: str | None, token: str
|
||||
cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session
|
||||
) -> InvitationDetailDict | None:
|
||||
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
|
||||
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token, session=session)
|
||||
if invitation or not email or email == email.lower():
|
||||
return invitation
|
||||
normalized_email = email.lower()
|
||||
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
|
||||
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token, session=session)
|
||||
|
||||
|
||||
def _generate_refresh_token(length: int = 64):
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import delete, or_, select, update
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -300,17 +301,19 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
|
||||
def update_app_annotation_directly(
|
||||
cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str, session: scoped_session
|
||||
):
|
||||
# get app info
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
app = session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = db.session.get(MessageAnnotation, annotation_id)
|
||||
annotation = session.get(MessageAnnotation, annotation_id)
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
@ -326,9 +329,9 @@ class AppAnnotationService:
|
||||
annotation.content = answer
|
||||
annotation.question = question
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
app_annotation_setting = db.session.scalar(
|
||||
app_annotation_setting = session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
|
||||
)
|
||||
|
||||
@ -344,33 +347,33 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str, session: scoped_session):
|
||||
# get app info
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
app = session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = db.session.get(MessageAnnotation, annotation_id)
|
||||
annotation = session.get(MessageAnnotation, annotation_id)
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
db.session.delete(annotation)
|
||||
session.delete(annotation)
|
||||
|
||||
annotation_hit_histories = db.session.scalars(
|
||||
annotation_hit_histories = session.scalars(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
).all()
|
||||
if annotation_hit_histories:
|
||||
for annotation_hit_history in annotation_hit_histories:
|
||||
db.session.delete(annotation_hit_history)
|
||||
session.delete(annotation_hit_history)
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
# if annotation reply is enabled , delete annotation index
|
||||
app_annotation_setting = db.session.scalar(
|
||||
app_annotation_setting = session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
|
||||
)
|
||||
|
||||
|
||||
@ -91,4 +91,4 @@ class OAuthServerService:
|
||||
|
||||
user_id_str = user_account_id.decode("utf-8")
|
||||
|
||||
return AccountService.load_user(user_id_str)
|
||||
return AccountService.load_user(user_id_str, db.session)
|
||||
|
||||
@ -36,7 +36,9 @@ class WorkspaceService:
|
||||
feature = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
can_replace_logo = feature.can_replace_logo
|
||||
|
||||
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
|
||||
if can_replace_logo and TenantService.has_roles(
|
||||
tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], session=db.session
|
||||
):
|
||||
base_url = dify_config.FILES_URL
|
||||
replace_webapp_logo = (
|
||||
f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"
|
||||
|
||||
@ -84,6 +84,7 @@ def setup_account(request) -> Generator[Account, None, None]:
|
||||
password=secrets.token_hex(16),
|
||||
ip_address="localhost",
|
||||
language="en-US",
|
||||
session=db.session,
|
||||
)
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
|
||||
@ -548,6 +548,7 @@ class TestAccountGeneration:
|
||||
provider="github",
|
||||
language="en-US",
|
||||
timezone=None,
|
||||
session=ANY,
|
||||
)
|
||||
else:
|
||||
mock_register_service.register.assert_not_called()
|
||||
@ -581,6 +582,7 @@ class TestAccountGeneration:
|
||||
provider="github",
|
||||
language="en-US",
|
||||
timezone=None,
|
||||
session=ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@ -612,6 +614,7 @@ class TestAccountGeneration:
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
session=ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@ -643,6 +646,7 @@ class TestAccountGeneration:
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone=None,
|
||||
session=ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@ -673,7 +677,7 @@ class TestAccountGeneration:
|
||||
|
||||
assert result == mock_account
|
||||
assert oauth_new_user is False
|
||||
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
|
||||
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace", session=ANY)
|
||||
mock_tenant_service.create_tenant_member.assert_called_once_with(
|
||||
mock_new_tenant, mock_account, ANY, role="owner"
|
||||
)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Literal
|
||||
from unittest.mock import patch
|
||||
@ -12,7 +12,6 @@ from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
|
||||
from models import Account, Tenant
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -46,20 +45,25 @@ def make_account(db_session_with_containers: Session) -> Callable[..., Account]:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
if with_owner_tenant:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
account, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
return account
|
||||
|
||||
return _make
|
||||
|
||||
|
||||
def add_tenant_for_account(account: Account, *, role: str = "normal", name: str = "Second WS") -> Tenant:
|
||||
def add_tenant_for_account(
|
||||
account: Account, *, session: Session, role: str = "normal", name: str = "Second WS"
|
||||
) -> Tenant:
|
||||
"""Create an additional tenant and join ``account`` to it (real service calls)."""
|
||||
with patch("services.account_service.FeatureService") as mock_feature_service:
|
||||
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
|
||||
tenant = TenantService.create_tenant(name=name)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role=role)
|
||||
tenant = TenantService.create_tenant(name=name, session=session)
|
||||
TenantService.create_tenant_member(tenant, account, session, role=role)
|
||||
return tenant
|
||||
|
||||
|
||||
@ -93,7 +97,7 @@ def account_auth_context(
|
||||
*,
|
||||
token_id: uuid.UUID,
|
||||
client_id: str = "integration-cli",
|
||||
) -> Iterator[AuthContext]:
|
||||
) -> Generator[AuthContext]:
|
||||
"""Publish an account ``AuthContext`` for handlers that read ``get_auth_ctx()``.
|
||||
|
||||
The auth pipeline normally sets this ContextVar; the integration suite
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Callable
|
||||
from inspect import unwrap
|
||||
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.openapi.account import AccountApi
|
||||
from models import Account
|
||||
@ -34,11 +35,13 @@ class TestAccountInfo:
|
||||
# the only workspace the account belongs to.
|
||||
assert result.default_workspace_id == owner_tenant.id
|
||||
|
||||
def test_lists_all_joined_workspaces(self, app: Flask, make_account: Callable[..., Account]) -> None:
|
||||
def test_lists_all_joined_workspaces(
|
||||
self, app: Flask, db_session_with_containers: Session, make_account: Callable[..., Account]
|
||||
) -> None:
|
||||
account = make_account()
|
||||
owner_tenant = account.current_tenant
|
||||
assert owner_tenant is not None
|
||||
second = add_tenant_for_account(account, role="normal", name="Second WS")
|
||||
second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS")
|
||||
|
||||
api = AccountApi()
|
||||
with app.test_request_context("/openapi/v1/account"):
|
||||
|
||||
@ -81,8 +81,9 @@ def _app_and_account(db_session: Session, *, mode: str = "chat") -> tuple[App, A
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session)
|
||||
tenant = account.current_tenant
|
||||
assert tenant is not None
|
||||
app_args = CreateAppParams(
|
||||
|
||||
@ -41,7 +41,7 @@ class TestWorkspacesList:
|
||||
account = make_account()
|
||||
owner_tenant = account.current_tenant
|
||||
assert owner_tenant is not None
|
||||
second = add_tenant_for_account(account, role="normal", name="Second WS")
|
||||
second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS")
|
||||
|
||||
api = WorkspacesApi()
|
||||
with app.test_request_context("/openapi/v1/workspaces"):
|
||||
@ -90,7 +90,9 @@ class TestWorkspaceSwitch:
|
||||
account = make_account()
|
||||
owner_tenant = account.current_tenant
|
||||
assert owner_tenant is not None
|
||||
target = add_tenant_for_account(account, role="normal", name="Switch Target")
|
||||
target = add_tenant_for_account(
|
||||
account, session=db_session_with_containers, role="normal", name="Switch Target"
|
||||
)
|
||||
|
||||
api = WorkspaceSwitchApi()
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{target.id}/switch", method="POST"):
|
||||
|
||||
@ -27,8 +27,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -88,8 +89,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -141,8 +143,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -194,8 +197,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -257,8 +261,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -291,8 +296,11 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
account1, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
|
||||
tenant1 = account1.current_tenant
|
||||
|
||||
account2 = AccountService.create_account(
|
||||
@ -300,8 +308,11 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
account2, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
|
||||
tenant2 = account2.current_tenant
|
||||
|
||||
# Create dataset for tenant1
|
||||
@ -367,8 +378,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Don't create any datasets
|
||||
@ -391,8 +403,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create multiple datasets
|
||||
@ -452,8 +465,9 @@ class TestKnowledgeRetrievalIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -520,8 +534,9 @@ class TestKnowledgeRetrievalIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset but no documents
|
||||
@ -568,8 +583,9 @@ class TestKnowledgeRetrievalIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -114,8 +114,9 @@ class TestAgentService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
|
||||
@ -81,8 +81,9 @@ class TestAnnotationService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments
|
||||
@ -280,7 +281,9 @@ class TestAnnotationService:
|
||||
"question": fake.sentence(),
|
||||
"answer": fake.text(max_nb_chars=200),
|
||||
}
|
||||
updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id)
|
||||
updated_annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
updated_args, app.id, annotation.id, db_session_with_containers
|
||||
)
|
||||
|
||||
# Verify annotation was updated correctly
|
||||
assert updated_annotation.id == annotation.id
|
||||
@ -567,7 +570,7 @@ class TestAnnotationService:
|
||||
annotation_id = annotation.id
|
||||
|
||||
# Delete the annotation
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers)
|
||||
|
||||
# Verify annotation was deleted
|
||||
|
||||
@ -595,7 +598,7 @@ class TestAnnotationService:
|
||||
|
||||
# Try to delete annotation with non-existent app
|
||||
with pytest.raises(NotFound, match="App not found"):
|
||||
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id, db_session_with_containers)
|
||||
|
||||
def test_delete_app_annotation_annotation_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -609,7 +612,7 @@ class TestAnnotationService:
|
||||
|
||||
# Try to delete non-existent annotation
|
||||
with pytest.raises(NotFound, match="Annotation not found"):
|
||||
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id, db_session_with_containers)
|
||||
|
||||
def test_enable_app_annotation_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -1225,7 +1228,9 @@ class TestAnnotationService:
|
||||
"question": fake.sentence(),
|
||||
"answer": fake.text(max_nb_chars=200),
|
||||
}
|
||||
updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id)
|
||||
updated_annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
updated_args, app.id, annotation.id, db_session_with_containers
|
||||
)
|
||||
|
||||
# Verify annotation was updated correctly
|
||||
assert updated_annotation.id == annotation.id
|
||||
@ -1295,7 +1300,7 @@ class TestAnnotationService:
|
||||
mock_external_service_dependencies["delete_task"].delay.reset_mock()
|
||||
|
||||
# Delete the annotation
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers)
|
||||
|
||||
# Verify annotation was deleted
|
||||
deleted_annotation = (
|
||||
|
||||
@ -57,8 +57,9 @@ class TestAPIBasedExtensionService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
@ -145,8 +145,11 @@ class TestAppDslService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
account, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
app_args = CreateAppParams(
|
||||
name=fake.company(),
|
||||
|
||||
@ -166,8 +166,9 @@ class TestAppGenerateService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
from services.app_service import AppService, CreateAppParams
|
||||
|
||||
@ -62,8 +62,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments
|
||||
@ -119,8 +120,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
@ -162,8 +164,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -210,8 +213,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
@ -261,8 +265,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
@ -344,8 +349,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
from models import AppStar
|
||||
@ -404,8 +410,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
from services.app_service import AppService, CreateAppParams, StarredAppListParams
|
||||
@ -500,8 +507,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
@ -566,14 +574,18 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
first_account, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(first_account, name=fake.company())
|
||||
tenant = first_account.current_tenant
|
||||
second_account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
@ -623,8 +635,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
@ -685,8 +698,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -755,8 +769,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
from services.app_service import AppService, CreateAppParams
|
||||
@ -807,8 +822,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
from services.app_service import AppService, CreateAppParams
|
||||
@ -857,8 +873,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -910,8 +927,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -971,8 +989,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -1030,8 +1049,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -1089,8 +1109,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -1139,8 +1160,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -1190,8 +1212,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -1249,8 +1272,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -1287,8 +1311,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -1326,8 +1351,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app first
|
||||
@ -1375,8 +1401,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import CreateAppParams
|
||||
@ -1411,8 +1438,9 @@ class TestAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
|
||||
@ -98,8 +98,9 @@ class TestMessageService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Setup app creation arguments
|
||||
@ -648,8 +649,11 @@ class TestMessageService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
other_account, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company())
|
||||
|
||||
# Test getting message with different user
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
|
||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -172,4 +172,4 @@ class TestOAuthServerServiceTokenOperations:
|
||||
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
|
||||
|
||||
assert result is expected_user
|
||||
mock_load.assert_called_once_with("user-88")
|
||||
mock_load.assert_called_once_with("user-88", ANY)
|
||||
|
||||
@ -51,8 +51,9 @@ class TestOpsService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(
|
||||
|
||||
@ -68,8 +68,9 @@ class TestSavedMessageService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
|
||||
@ -55,6 +55,7 @@ class TestTriggerProviderService:
|
||||
def _create_test_account_and_tenant(
|
||||
self,
|
||||
mock_external_service_dependencies: MockExternalServiceDependencies,
|
||||
db_session_with_containers: Session,
|
||||
) -> tuple[Account, Tenant]:
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
@ -83,8 +84,9 @@ class TestTriggerProviderService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
assert tenant is not None
|
||||
|
||||
@ -164,7 +166,9 @@ class TestTriggerProviderService:
|
||||
- Database state is correctly updated
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
mock_external_service_dependencies, db_session_with_containers
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -262,7 +266,9 @@ class TestTriggerProviderService:
|
||||
- Merged credentials contain only new values
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
mock_external_service_dependencies, db_session_with_containers
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -320,7 +326,9 @@ class TestTriggerProviderService:
|
||||
- Original credentials are preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
mock_external_service_dependencies, db_session_with_containers
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -376,7 +384,9 @@ class TestTriggerProviderService:
|
||||
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
mock_external_service_dependencies, db_session_with_containers
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -434,7 +444,9 @@ class TestTriggerProviderService:
|
||||
- Original subscription state is preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
mock_external_service_dependencies, db_session_with_containers
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -474,9 +486,8 @@ class TestTriggerProviderService:
|
||||
assert subscription.name == original_name
|
||||
assert subscription.parameters == original_parameters
|
||||
|
||||
@pytest.mark.usefixtures("db_session_with_containers")
|
||||
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||
self, mock_external_service_dependencies: MockExternalServiceDependencies
|
||||
self, mock_external_service_dependencies: MockExternalServiceDependencies, db_session_with_containers: Session
|
||||
) -> None:
|
||||
"""
|
||||
Test error when subscription is not found.
|
||||
@ -485,7 +496,9 @@ class TestTriggerProviderService:
|
||||
- Proper error is raised when subscription doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
mock_external_service_dependencies, db_session_with_containers
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
fake_subscription_id = fake.uuid4()
|
||||
@ -509,7 +522,9 @@ class TestTriggerProviderService:
|
||||
- Error is raised when new name conflicts with existing subscription
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
mock_external_service_dependencies, db_session_with_containers
|
||||
)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
@ -72,8 +72,9 @@ class TestWebConversationService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
|
||||
@ -63,8 +63,9 @@ class TestWebhookService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
assert tenant is not None
|
||||
|
||||
|
||||
@ -78,8 +78,9 @@ class TestWorkflowAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
@ -126,8 +127,9 @@ class TestWorkflowAppService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
return tenant, account
|
||||
|
||||
@ -74,8 +74,9 @@ class TestWorkflowRunService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
@ -530,8 +531,9 @@ class TestWorkflowRunService:
|
||||
name="Test User",
|
||||
password="password123",
|
||||
interface_language="en-US",
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app
|
||||
@ -581,8 +583,9 @@ class TestWorkflowRunService:
|
||||
name="Test User",
|
||||
password="password123",
|
||||
interface_language="en-US",
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app
|
||||
@ -632,8 +635,9 @@ class TestWorkflowRunService:
|
||||
name="Test User",
|
||||
password="password123",
|
||||
interface_language="en-US",
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app
|
||||
|
||||
@ -89,8 +89,9 @@ class TestWorkflowToolManageService:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create app with realistic data
|
||||
|
||||
@ -90,8 +90,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -211,8 +212,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -255,8 +257,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Test different index types
|
||||
@ -342,8 +345,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -424,8 +428,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -525,8 +530,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -625,8 +631,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -717,8 +724,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -820,8 +828,11 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
account, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
accounts.append(account)
|
||||
tenants.append(tenant)
|
||||
@ -926,8 +937,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -1031,8 +1043,9 @@ class TestCleanNotionDocumentTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset with built-in fields enabled
|
||||
|
||||
@ -67,8 +67,9 @@ class TestDealDatasetVectorIndexTask:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
assert tenant is not None
|
||||
return account, tenant
|
||||
|
||||
@ -8,7 +8,7 @@ This module tests the account activation mechanism including:
|
||||
- Initial login after activation
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -41,7 +41,7 @@ class TestActivateCheckApi:
|
||||
}
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
|
||||
def test_check_valid_invitation_token(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock):
|
||||
"""
|
||||
Test checking valid invitation token.
|
||||
|
||||
@ -67,7 +67,9 @@ class TestActivateCheckApi:
|
||||
assert response["data"]["email"] == "invitee@example.com"
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_valid_invitation_token_includes_account_status(self, mock_get_invitation, app, mock_invitation):
|
||||
def test_check_valid_invitation_token_includes_account_status(
|
||||
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
|
||||
):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = AccountStatus.ACTIVE
|
||||
mock_invitation["account"] = mock_account
|
||||
@ -103,7 +105,9 @@ class TestActivateCheckApi:
|
||||
assert response["is_valid"] is False
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
|
||||
def test_check_token_without_workspace_id(
|
||||
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
|
||||
):
|
||||
"""
|
||||
Test checking token without workspace ID.
|
||||
|
||||
@ -121,10 +125,10 @@ class TestActivateCheckApi:
|
||||
|
||||
# Assert
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
|
||||
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token", session=ANY)
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
|
||||
def test_check_token_without_email(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation):
|
||||
"""
|
||||
Test checking token without email parameter.
|
||||
|
||||
@ -142,10 +146,12 @@ class TestActivateCheckApi:
|
||||
|
||||
# Assert
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token", session=ANY)
|
||||
|
||||
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
|
||||
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
|
||||
def test_check_token_normalizes_email_to_lowercase(
|
||||
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
|
||||
):
|
||||
"""Ensure token validation uses lowercase emails."""
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
|
||||
@ -156,7 +162,7 @@ class TestActivateCheckApi:
|
||||
response = api.get()
|
||||
|
||||
assert response["is_valid"] is True
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY)
|
||||
|
||||
|
||||
class TestActivateApi:
|
||||
@ -554,7 +560,7 @@ class TestActivateApi:
|
||||
response = api.post()
|
||||
|
||||
assert response["result"] == "success"
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
|
||||
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY)
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.TenantService.create_tenant_member")
|
||||
@ -593,7 +599,7 @@ class TestActivateApi:
|
||||
mock_create_tenant_member.assert_called_once_with(
|
||||
mock_invitation["tenant"], mock_account, mock_db.session, role=TenantAccountRole.ADMIN
|
||||
)
|
||||
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id)
|
||||
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY)
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
|
||||
@patch("controllers.console.auth.activate.TenantService.create_tenant_member")
|
||||
@ -628,5 +634,5 @@ class TestActivateApi:
|
||||
|
||||
assert response["result"] == "success"
|
||||
mock_create_tenant_member.assert_not_called()
|
||||
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id)
|
||||
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY)
|
||||
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
@ -25,6 +25,7 @@ def test_create_new_account_uses_requested_language(mock_create_account):
|
||||
password="ValidPass123!",
|
||||
interface_language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
session=ANY,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ This module tests the email code login mechanism including:
|
||||
"""
|
||||
|
||||
import base64
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -368,6 +368,7 @@ class TestEmailCodeLoginApi:
|
||||
name="newuser@example.com",
|
||||
interface_language="en-US",
|
||||
timezone="Asia/Shanghai",
|
||||
session=ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
|
||||
@ -9,7 +9,7 @@ This module tests the core authentication endpoints including:
|
||||
"""
|
||||
|
||||
import base64
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from unittest.mock import ANY, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -129,7 +129,7 @@ class TestLoginApi:
|
||||
response = login_api.post()
|
||||
|
||||
# Assert
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None, session=ANY)
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_rate_limit.assert_called_once_with("test@example.com")
|
||||
assert response.json["result"] == "success"
|
||||
@ -184,7 +184,7 @@ class TestLoginApi:
|
||||
response = login_api.post()
|
||||
|
||||
# Assert
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token")
|
||||
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token", session=ANY)
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@ -407,13 +407,13 @@ class TestLoginApi:
|
||||
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
|
||||
def test_login_retries_with_lowercase_email(
|
||||
self,
|
||||
mock_reset_rate_limit,
|
||||
mock_login_service,
|
||||
mock_get_tenants,
|
||||
mock_add_rate_limit,
|
||||
mock_authenticate,
|
||||
mock_get_invitation,
|
||||
mock_is_rate_limit,
|
||||
mock_reset_rate_limit: MagicMock,
|
||||
mock_login_service: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
mock_add_rate_limit: MagicMock,
|
||||
mock_authenticate: MagicMock,
|
||||
mock_get_invitation: MagicMock,
|
||||
mock_is_rate_limit: MagicMock,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
mock_account,
|
||||
@ -435,8 +435,8 @@ class TestLoginApi:
|
||||
|
||||
assert response.json["result"] == "success"
|
||||
assert mock_authenticate.call_args_list == [
|
||||
(("Upper@Example.com", "ValidPass123!", None), {}),
|
||||
(("upper@example.com", "ValidPass123!", None), {}),
|
||||
(("Upper@Example.com", "ValidPass123!", None), {"session": ANY}),
|
||||
(("upper@example.com", "ValidPass123!", None), {"session": ANY}),
|
||||
]
|
||||
mock_add_rate_limit.assert_not_called()
|
||||
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
|
||||
@ -447,10 +447,10 @@ class TestLoginApi:
|
||||
@patch("controllers.console.auth.login._get_account_with_case_fallback")
|
||||
def test_email_code_login_logs_banned_account(
|
||||
self,
|
||||
mock_get_account,
|
||||
mock_revoke_token,
|
||||
mock_get_token_data,
|
||||
mock_db,
|
||||
mock_get_account: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_get_token_data: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
|
||||
@ -491,7 +491,9 @@ class TestLogoutApi:
|
||||
|
||||
@patch("controllers.console.auth.login.AccountService.logout")
|
||||
@patch("controllers.console.auth.login.flask_login.logout_user")
|
||||
def test_successful_logout(self, mock_logout_user, mock_service_logout, app: Flask, mock_account):
|
||||
def test_successful_logout(
|
||||
self, mock_logout_user: MagicMock, mock_service_logout: MagicMock, app: Flask, mock_account
|
||||
):
|
||||
"""
|
||||
Test successful logout flow.
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -66,8 +66,9 @@ def test_generate_account_registers_with_browser_timezone(
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
session=ANY,
|
||||
)
|
||||
mock_link_account.assert_called_once_with("github", "github-123", account)
|
||||
mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY)
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
|
||||
@ -97,8 +98,9 @@ def test_generate_account_prefers_state_language_over_accept_language(
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone=None,
|
||||
session=ANY,
|
||||
)
|
||||
mock_link_account.assert_called_once_with("github", "github-123", account)
|
||||
mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY)
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
|
||||
@ -8,7 +8,7 @@ This module tests the token refresh mechanism including:
|
||||
- Error handling for invalid tokens
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -70,7 +70,7 @@ class TestRefreshTokenApi:
|
||||
|
||||
# Assert
|
||||
mock_extract_token.assert_called_once()
|
||||
mock_refresh_token.assert_called_once_with("valid_refresh_token")
|
||||
mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY)
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
@ -191,7 +191,7 @@ class TestRefreshTokenApi:
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
# Verify new token pair was generated
|
||||
mock_refresh_token.assert_called_once_with("valid_refresh_token")
|
||||
mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY)
|
||||
# In real implementation, cookies would be set with new values
|
||||
assert mock_token_pair.access_token == "new_access_token"
|
||||
assert mock_token_pair.refresh_token == "new_refresh_token"
|
||||
|
||||
@ -38,7 +38,7 @@ def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 1)
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
@ -48,7 +48,7 @@ def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.Mo
|
||||
|
||||
def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
@ -60,7 +60,7 @@ def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.M
|
||||
|
||||
def test_validate_init_password_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import inspect
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -394,12 +394,12 @@ class TestChangeEmailReset:
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_new_email_before_update(
|
||||
self,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_is_freeze: MagicMock,
|
||||
mock_check_unique: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_update_account: MagicMock,
|
||||
mock_send_notify: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
@ -424,9 +424,9 @@ class TestChangeEmailReset:
|
||||
method(api, current_user)
|
||||
|
||||
mock_is_freeze.assert_called_once_with("new@example.com")
|
||||
mock_check_unique.assert_called_once_with("new@example.com")
|
||||
mock_check_unique.assert_called_once_with("new@example.com", session=ANY)
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
|
||||
mock_update_account.assert_called_once_with(current_user, email="new@example.com", session=ANY)
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@ -437,12 +437,12 @@ class TestChangeEmailReset:
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_reject_reset_when_token_phase_is_not_new_verified(
|
||||
self,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_is_freeze: MagicMock,
|
||||
mock_check_unique: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_update_account: MagicMock,
|
||||
mock_send_notify: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
|
||||
@ -480,12 +480,12 @@ class TestChangeEmailReset:
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
|
||||
self,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_is_freeze: MagicMock,
|
||||
mock_check_unique: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_update_account: MagicMock,
|
||||
mock_send_notify: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
"""A verified token for address A must not be replayed to change to address B."""
|
||||
@ -523,12 +523,12 @@ class TestChangeEmailReset:
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_reject_reset_when_token_account_id_does_not_match_current_user(
|
||||
self,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_is_freeze: MagicMock,
|
||||
mock_check_unique: MagicMock,
|
||||
mock_get_data: MagicMock,
|
||||
mock_revoke_token: MagicMock,
|
||||
mock_update_account: MagicMock,
|
||||
mock_send_notify: MagicMock,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
@ -575,9 +575,9 @@ class TestAccountServiceSendChangeEmailEmail:
|
||||
@patch("services.account_service.AccountService.generate_change_email_token")
|
||||
def test_should_bind_account_id_and_target_email_into_generated_token(
|
||||
self,
|
||||
mock_generate_token,
|
||||
mock_rate_limiter,
|
||||
mock_mail_task,
|
||||
mock_generate_token: MagicMock,
|
||||
mock_rate_limiter: MagicMock,
|
||||
mock_mail_task: MagicMock,
|
||||
):
|
||||
mock_rate_limiter.is_rate_limited.return_value = False
|
||||
mock_generate_token.return_value = "the-token"
|
||||
@ -665,7 +665,7 @@ class TestAccountDeletionFeedback:
|
||||
class TestCheckEmailUnique:
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, app: Flask):
|
||||
def test_should_normalize_email(self, mock_is_freeze: MagicMock, mock_check_unique: MagicMock, app: Flask):
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
|
||||
@ -680,7 +680,7 @@ class TestCheckEmailUnique:
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_is_freeze.assert_called_once_with("case@test.com")
|
||||
mock_check_unique.assert_called_once_with("case@test.com")
|
||||
mock_check_unique.assert_called_once_with("case@test.com", session=ANY)
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
|
||||
@ -8,7 +8,7 @@ handler tests use inspect.unwrap() to bypass them and focus on business logic.
|
||||
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -115,7 +115,7 @@ class TestEnterpriseWorkspace:
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["name"] == "My Workspace"
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY)
|
||||
mock_tenant_svc.create_tenant_member.assert_called_once_with(
|
||||
mock_tenant, mock_account, mock_db.session, role="owner"
|
||||
)
|
||||
@ -183,5 +183,5 @@ class TestEnterpriseWorkspaceNoOwnerEmail:
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["encrypt_public_key"] == "pub-key"
|
||||
assert result["tenant"]["custom_config"] == {}
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY)
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
|
||||
@ -179,7 +179,7 @@ class TestAccountService:
|
||||
mock_password_dependencies["compare_password"].return_value = True
|
||||
|
||||
# Execute test
|
||||
result = AccountService.authenticate("test@example.com", "password")
|
||||
result = AccountService.authenticate("test@example.com", "password", session=mock_db_dependencies["db"].session)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_account
|
||||
@ -191,7 +191,11 @@ class TestAccountService:
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password"
|
||||
AccountPasswordError,
|
||||
AccountService.authenticate,
|
||||
"notfound@example.com",
|
||||
"password",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
def test_authenticate_account_banned(self, mock_db_dependencies):
|
||||
@ -202,7 +206,13 @@ class TestAccountService:
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_account
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
|
||||
self._assert_exception_raised(
|
||||
AccountLoginError,
|
||||
AccountService.authenticate,
|
||||
"banned@example.com",
|
||||
"password",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
def test_authenticate_password_error(self, mock_db_dependencies, mock_password_dependencies):
|
||||
"""Test authentication with wrong password."""
|
||||
@ -215,7 +225,11 @@ class TestAccountService:
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
AccountPasswordError, AccountService.authenticate, "test@example.com", "wrongpassword"
|
||||
AccountPasswordError,
|
||||
AccountService.authenticate,
|
||||
"test@example.com",
|
||||
"wrongpassword",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
def test_authenticate_pending_account_activates(self, mock_db_dependencies, mock_password_dependencies):
|
||||
@ -228,7 +242,9 @@ class TestAccountService:
|
||||
mock_password_dependencies["compare_password"].return_value = True
|
||||
|
||||
# Execute test
|
||||
result = AccountService.authenticate("pending@example.com", "password")
|
||||
result = AccountService.authenticate(
|
||||
"pending@example.com", "password", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_account
|
||||
@ -253,6 +269,7 @@ class TestAccountService:
|
||||
interface_language="en-US",
|
||||
password="password123",
|
||||
interface_theme="light",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
@ -290,6 +307,7 @@ class TestAccountService:
|
||||
interface_language="en-US",
|
||||
password="password123",
|
||||
timezone="Asia/Shanghai",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
assert result.timezone == "Asia/Shanghai"
|
||||
@ -309,6 +327,7 @@ class TestAccountService:
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
session=MagicMock(),
|
||||
)
|
||||
|
||||
def test_create_account_email_frozen(self, mock_db_dependencies, mock_external_service_dependencies):
|
||||
@ -325,6 +344,7 @@ class TestAccountService:
|
||||
email="frozen@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
dify_config.BILLING_ENABLED = False
|
||||
|
||||
@ -341,6 +361,7 @@ class TestAccountService:
|
||||
interface_language="zh-CN",
|
||||
password=None,
|
||||
interface_theme="dark",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
@ -375,7 +396,9 @@ class TestAccountService:
|
||||
mock_password_dependencies["hash_password"].return_value = b"new_hashed_password"
|
||||
|
||||
# Execute test
|
||||
result = AccountService.update_account_password(mock_account, "old_password", "new_password123")
|
||||
result = AccountService.update_account_password(
|
||||
mock_account, "old_password", "new_password123", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_account
|
||||
@ -391,7 +414,7 @@ class TestAccountService:
|
||||
# Verify database operations
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
def test_update_account_password_current_password_incorrect(self, mock_password_dependencies):
|
||||
def test_update_account_password_current_password_incorrect(self, mock_db_dependencies, mock_password_dependencies):
|
||||
"""Test password update with incorrect current password."""
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
@ -404,6 +427,7 @@ class TestAccountService:
|
||||
mock_account,
|
||||
"wrong_password",
|
||||
"new_password123",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify password comparison was called
|
||||
@ -411,7 +435,7 @@ class TestAccountService:
|
||||
"wrong_password", "hashed_password", "salt"
|
||||
)
|
||||
|
||||
def test_update_account_password_invalid_new_password(self, mock_password_dependencies):
|
||||
def test_update_account_password_invalid_new_password(self, mock_db_dependencies, mock_password_dependencies):
|
||||
"""Test password update with invalid new password."""
|
||||
# Setup test data
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
@ -420,7 +444,12 @@ class TestAccountService:
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(
|
||||
ValueError, AccountService.update_account_password, mock_account, "old_password", "short"
|
||||
ValueError,
|
||||
AccountService.update_account_password,
|
||||
mock_account,
|
||||
"old_password",
|
||||
"short",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify password validation was called
|
||||
@ -447,19 +476,19 @@ class TestAccountService:
|
||||
mock_datetime.UTC = "UTC"
|
||||
|
||||
# Execute test
|
||||
result = AccountService.load_user("user-123")
|
||||
result = AccountService.load_user("user-123", mock_db_dependencies["db"].session)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_account
|
||||
assert mock_account.set_tenant_id.called
|
||||
mock_refresh_last_active.assert_called_once_with(mock_account)
|
||||
mock_refresh_last_active.assert_called_once_with(mock_account, mock_db_dependencies["db"].session)
|
||||
|
||||
def test_load_user_not_found(self, mock_db_dependencies):
|
||||
"""Test user loading when user does not exist."""
|
||||
mock_db_dependencies["db"].session.get.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = AccountService.load_user("non-existent-user")
|
||||
result = AccountService.load_user("non-existent-user", mock_db_dependencies["db"].session)
|
||||
|
||||
# Verify results
|
||||
assert result is None
|
||||
@ -500,14 +529,14 @@ class TestAccountService:
|
||||
mock_naive_utc_now.return_value = mock_now
|
||||
|
||||
# Execute test
|
||||
result = AccountService.load_user("user-123")
|
||||
result = AccountService.load_user("user-123", mock_db_dependencies["db"].session)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_account
|
||||
assert mock_available_tenant.current is True
|
||||
assert mock_available_tenant.last_opened_at == mock_now
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
mock_refresh_last_active.assert_called_once_with(mock_account)
|
||||
mock_refresh_last_active.assert_called_once_with(mock_account, mock_db_dependencies["db"].session)
|
||||
|
||||
def test_load_user_no_tenants(self, mock_db_dependencies):
|
||||
"""Test user loading when user has no tenants at all."""
|
||||
@ -525,7 +554,7 @@ class TestAccountService:
|
||||
mock_datetime.UTC = "UTC"
|
||||
|
||||
# Execute test
|
||||
result = AccountService.load_user("user-123")
|
||||
result = AccountService.load_user("user-123", mock_db_dependencies["db"].session)
|
||||
|
||||
# Verify results
|
||||
assert result is None
|
||||
@ -542,7 +571,7 @@ class TestAccountService:
|
||||
):
|
||||
mock_redis_client.set.return_value = True
|
||||
|
||||
AccountService._refresh_account_last_active(mock_account)
|
||||
AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session)
|
||||
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
"account_last_active_refresh:user-123",
|
||||
@ -565,7 +594,7 @@ class TestAccountService:
|
||||
):
|
||||
mock_redis_client.set.return_value = None
|
||||
|
||||
AccountService._refresh_account_last_active(mock_account)
|
||||
AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session)
|
||||
|
||||
mock_redis_client.set.assert_called_once_with(
|
||||
"account_last_active_refresh:user-123",
|
||||
@ -586,7 +615,7 @@ class TestAccountService:
|
||||
patch("services.account_service.naive_utc_now", return_value=now),
|
||||
patch("services.account_service.redis_client") as mock_redis_client,
|
||||
):
|
||||
AccountService._refresh_account_last_active(mock_account)
|
||||
AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session)
|
||||
|
||||
mock_redis_client.set.assert_not_called()
|
||||
mock_db_dependencies["db"].session.execute.assert_not_called()
|
||||
@ -736,7 +765,9 @@ class TestTenantService:
|
||||
mock_credit_pool_db.session.commit = MagicMock()
|
||||
|
||||
# Execute test
|
||||
TenantService.create_owner_tenant_if_not_exist(mock_account)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
mock_account, session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
# Verify tenant was created with correct parameters
|
||||
mock_db_dependencies["db"].session.add.assert_called()
|
||||
@ -838,7 +869,9 @@ class TestTenantService:
|
||||
mock_sync.return_value = True
|
||||
|
||||
# Act
|
||||
TenantService.remove_member_from_tenant(mock_tenant, mock_pending_member, mock_operator)
|
||||
TenantService.remove_member_from_tenant(
|
||||
mock_tenant, mock_pending_member, mock_operator, session=mock_db.session
|
||||
)
|
||||
|
||||
# Assert: enterprise sync still receives the correct member ID
|
||||
mock_sync.assert_called_once_with(
|
||||
@ -878,7 +911,9 @@ class TestTenantService:
|
||||
mock_sync.return_value = True
|
||||
|
||||
# Act
|
||||
TenantService.remove_member_from_tenant(mock_tenant, mock_pending_member, mock_operator)
|
||||
TenantService.remove_member_from_tenant(
|
||||
mock_tenant, mock_pending_member, mock_operator, session=mock_db.session
|
||||
)
|
||||
|
||||
# Assert: only the join record should be deleted, not the account
|
||||
mock_db.session.delete.assert_called_once_with(mock_ta)
|
||||
@ -909,7 +944,9 @@ class TestTenantService:
|
||||
mock_sync.return_value = True
|
||||
|
||||
# Act
|
||||
TenantService.remove_member_from_tenant(mock_tenant, mock_active_member, mock_operator)
|
||||
TenantService.remove_member_from_tenant(
|
||||
mock_tenant, mock_active_member, mock_operator, session=mock_db.session
|
||||
)
|
||||
|
||||
# Assert: only the join record should be deleted
|
||||
mock_db.session.delete.assert_called_once_with(mock_ta)
|
||||
@ -934,7 +971,7 @@ class TestTenantService:
|
||||
mock_naive_utc_now.return_value = mock_now
|
||||
|
||||
# Execute test
|
||||
TenantService.switch_tenant(mock_account, "tenant-456")
|
||||
TenantService.switch_tenant(mock_account, "tenant-456", session=mock_db.session)
|
||||
|
||||
# Verify tenant was switched
|
||||
assert mock_tenant_join.current is True
|
||||
@ -947,7 +984,7 @@ class TestTenantService:
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
|
||||
|
||||
# Execute test and verify exception
|
||||
self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None)
|
||||
self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None, session=MagicMock())
|
||||
|
||||
# ==================== Role Management Tests ====================
|
||||
|
||||
@ -971,7 +1008,7 @@ class TestTenantService:
|
||||
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join]
|
||||
|
||||
# Execute test
|
||||
TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator)
|
||||
TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator, session=mock_db.session)
|
||||
|
||||
# Verify role was updated
|
||||
assert mock_target_join.role == "admin"
|
||||
@ -1005,7 +1042,9 @@ class TestTenantService:
|
||||
):
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(mock_account, is_setup=True)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
mock_account, is_setup=True, session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
mock_rbac_service.MemberRoles.replace.assert_called_once_with(
|
||||
tenant_id="tenant-rbac",
|
||||
@ -1030,7 +1069,7 @@ class TestTenantService:
|
||||
with patch("services.account_service.db") as mock_db:
|
||||
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join]
|
||||
|
||||
TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator)
|
||||
TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator, session=mock_db.session)
|
||||
|
||||
assert mock_target_join.role == "editor"
|
||||
self._assert_database_operations_called(mock_db)
|
||||
@ -1052,7 +1091,9 @@ class TestTenantService:
|
||||
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join]
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator)
|
||||
TenantService.update_member_role(
|
||||
mock_tenant, mock_member, "editor", mock_operator, session=mock_db.session
|
||||
)
|
||||
|
||||
def test_admin_cannot_promote_member_to_owner(self):
|
||||
"""Test admin cannot promote a non-owner member to owner."""
|
||||
@ -1071,7 +1112,9 @@ class TestTenantService:
|
||||
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join]
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
TenantService.update_member_role(mock_tenant, mock_member, "owner", mock_operator)
|
||||
TenantService.update_member_role(
|
||||
mock_tenant, mock_member, "owner", mock_operator, session=mock_db.session
|
||||
)
|
||||
|
||||
# ==================== Permission Check Tests ====================
|
||||
|
||||
@ -1089,7 +1132,9 @@ class TestTenantService:
|
||||
mock_db_dependencies["db"].session.scalar.return_value = mock_operator_join
|
||||
|
||||
# Execute test - should not raise exception
|
||||
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add")
|
||||
TenantService.check_member_permission(
|
||||
mock_tenant, mock_operator, mock_member, "add", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
def test_check_member_permission_operate_self(self):
|
||||
"""Test member permission check when operator tries to operate self."""
|
||||
@ -1108,6 +1153,7 @@ class TestTenantService:
|
||||
mock_operator,
|
||||
mock_operator, # Same as operator
|
||||
"add",
|
||||
session=MagicMock(),
|
||||
)
|
||||
|
||||
def test_admin_can_remove_non_owner_member(self, mock_db_dependencies):
|
||||
@ -1124,7 +1170,9 @@ class TestTenantService:
|
||||
)
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [mock_operator_join, mock_member_join]
|
||||
|
||||
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
|
||||
TenantService.check_member_permission(
|
||||
mock_tenant, mock_operator, mock_member, "remove", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
def test_admin_cannot_remove_owner_member(self, mock_db_dependencies):
|
||||
"""Test admin cannot remove an owner member."""
|
||||
@ -1141,7 +1189,9 @@ class TestTenantService:
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [mock_operator_join, mock_member_join]
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
|
||||
TenantService.check_member_permission(
|
||||
mock_tenant, mock_operator, mock_member, "remove", session=MagicMock()
|
||||
)
|
||||
|
||||
def test_rbac_member_can_remove_non_owner_member(self):
|
||||
"""Test RBAC workspace.member.manage allows removing a non-owner member."""
|
||||
@ -1158,7 +1208,9 @@ class TestTenantService:
|
||||
patch("services.account_service.RBACService.MyPermissions.get", return_value=mock_permissions),
|
||||
patch("services.account_service.AccountService.is_rbac_workspace_owner", return_value=False),
|
||||
):
|
||||
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
|
||||
TenantService.check_member_permission(
|
||||
mock_tenant, mock_operator, mock_member, "remove", session=MagicMock()
|
||||
)
|
||||
|
||||
def test_rbac_member_cannot_remove_without_permission(self):
|
||||
"""Test RBAC permission check rejects removal without workspace.member.manage."""
|
||||
@ -1175,7 +1227,9 @@ class TestTenantService:
|
||||
patch("services.account_service.RBACService.MyPermissions.get", return_value=mock_permissions),
|
||||
):
|
||||
with pytest.raises(NoPermissionError):
|
||||
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
|
||||
TenantService.check_member_permission(
|
||||
mock_tenant, mock_operator, mock_member, "remove", session=MagicMock()
|
||||
)
|
||||
|
||||
def test_rbac_member_cannot_remove_owner_member(self):
|
||||
"""Test RBAC permission check rejects removing an owner member."""
|
||||
@ -1193,7 +1247,9 @@ class TestTenantService:
|
||||
patch("services.account_service.AccountService.is_rbac_workspace_owner", return_value=True),
|
||||
):
|
||||
with pytest.raises(NoPermissionError):
|
||||
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
|
||||
TenantService.check_member_permission(
|
||||
mock_tenant, mock_operator, mock_member, "remove", session=MagicMock()
|
||||
)
|
||||
|
||||
def test_get_rbac_workspace_owner_account_id(self):
|
||||
mock_roles = MagicMock()
|
||||
@ -1304,7 +1360,14 @@ class TestRegisterService:
|
||||
mock_dify_setup.return_value = mock_dify_setup_instance
|
||||
|
||||
# Execute test
|
||||
RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1", "en-US")
|
||||
RegisterService.setup(
|
||||
"admin@example.com",
|
||||
"Admin User",
|
||||
"password123",
|
||||
"192.168.1.1",
|
||||
"en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
mock_create_account.assert_called_once_with(
|
||||
@ -1313,8 +1376,11 @@ class TestRegisterService:
|
||||
interface_language="en-US",
|
||||
password="password123",
|
||||
is_setup=True,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
mock_create_tenant.assert_called_once_with(
|
||||
account=mock_account, is_setup=True, session=mock_db_dependencies["db"].session
|
||||
)
|
||||
mock_create_tenant.assert_called_once_with(account=mock_account, is_setup=True)
|
||||
mock_dify_setup.assert_called_once()
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
@ -1337,6 +1403,7 @@ class TestRegisterService:
|
||||
"password123",
|
||||
"192.168.1.1",
|
||||
"en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify rollback operations were called
|
||||
@ -1369,10 +1436,13 @@ class TestRegisterService:
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password=None,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
assert result == mock_account
|
||||
mock_create_workspace.assert_called_once_with(account=mock_account)
|
||||
mock_create_workspace.assert_called_once_with(
|
||||
account=mock_account, session=mock_db_dependencies["db"].session
|
||||
)
|
||||
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
|
||||
|
||||
def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled(
|
||||
@ -1400,9 +1470,12 @@ class TestRegisterService:
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password=None,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
mock_create_workspace.assert_called_once_with(account=mock_account)
|
||||
mock_create_workspace.assert_called_once_with(
|
||||
account=mock_account, session=mock_db_dependencies["db"].session
|
||||
)
|
||||
mock_join_default_workspace.assert_not_called()
|
||||
|
||||
def test_create_account_and_tenant_still_calls_default_workspace_join_when_workspace_creation_fails(
|
||||
@ -1433,6 +1506,7 @@ class TestRegisterService:
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password=None,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
|
||||
@ -1470,6 +1544,7 @@ class TestRegisterService:
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
@ -1483,8 +1558,11 @@ class TestRegisterService:
|
||||
password="password123",
|
||||
is_setup=False,
|
||||
timezone=None,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
mock_create_tenant.assert_called_once_with(
|
||||
"Test User's Workspace", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
mock_create_tenant.assert_called_once_with("Test User's Workspace")
|
||||
mock_create_member.assert_called_once_with(
|
||||
mock_tenant, mock_account, mock_db_dependencies["db"].session, role="owner"
|
||||
)
|
||||
@ -1516,6 +1594,7 @@ class TestRegisterService:
|
||||
password="password123",
|
||||
language="en-US",
|
||||
create_workspace_required=False,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
assert result == mock_account
|
||||
@ -1546,6 +1625,7 @@ class TestRegisterService:
|
||||
password="password123",
|
||||
language="en-US",
|
||||
create_workspace_required=False,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
mock_join_default_workspace.assert_not_called()
|
||||
@ -1584,6 +1664,7 @@ class TestRegisterService:
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
|
||||
@ -1623,6 +1704,7 @@ class TestRegisterService:
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
|
||||
@ -1665,11 +1747,14 @@ class TestRegisterService:
|
||||
open_id="oauth123",
|
||||
provider="google",
|
||||
language="en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result == mock_account
|
||||
mock_link_account.assert_called_once_with("google", "oauth123", mock_account)
|
||||
mock_link_account.assert_called_once_with(
|
||||
"google", "oauth123", mock_account, session=mock_db_dependencies["db"].session
|
||||
)
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
def test_register_with_pending_status(self, mock_db_dependencies, mock_external_service_dependencies):
|
||||
@ -1707,6 +1792,7 @@ class TestRegisterService:
|
||||
password="password123",
|
||||
language="en-US",
|
||||
status=AccountStatus.PENDING,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
@ -1744,6 +1830,7 @@ class TestRegisterService:
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify rollback was called
|
||||
@ -1767,6 +1854,7 @@ class TestRegisterService:
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify rollback was called
|
||||
@ -1810,6 +1898,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
@ -1820,6 +1909,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "newuser@example.com")
|
||||
|
||||
@ -1856,6 +1946,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
mock_register.assert_called_once_with(
|
||||
@ -1864,13 +1955,22 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, mixed_email)
|
||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
||||
mock_check_permission.assert_called_once_with(
|
||||
mock_tenant,
|
||||
mock_inviter,
|
||||
None,
|
||||
"add",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
mock_create_member.assert_called_once_with(
|
||||
mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "normal"
|
||||
)
|
||||
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
|
||||
mock_switch_tenant.assert_called_once_with(
|
||||
mock_new_account, mock_tenant.id, session=mock_db_dependencies["db"].session
|
||||
)
|
||||
mock_generate_token.assert_called_once_with(
|
||||
mock_tenant, mock_new_account, "normal", requires_setup=True
|
||||
)
|
||||
@ -1912,6 +2012,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
# Verify results
|
||||
@ -1954,10 +2055,17 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="admin",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
assert result == "invite-token-123"
|
||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, mock_existing_account, "add")
|
||||
mock_check_permission.assert_called_once_with(
|
||||
mock_tenant,
|
||||
mock_inviter,
|
||||
mock_existing_account,
|
||||
"add",
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
mock_create_member.assert_not_called()
|
||||
mock_generate_token.assert_called_once_with(
|
||||
mock_tenant, mock_existing_account, "admin", requires_setup=False
|
||||
@ -1993,6 +2101,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="normal",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
mock_lookup.assert_called_once()
|
||||
|
||||
@ -2010,6 +2119,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="normal",
|
||||
inviter=None,
|
||||
session=MagicMock(),
|
||||
)
|
||||
|
||||
# ==================== RBAC Member Invitation Tests ====================
|
||||
@ -2048,6 +2158,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="rbac-role-id-123",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
assert result == "rbac-token"
|
||||
@ -2093,6 +2204,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="rbac-role-id-456",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
assert result == "rbac-token"
|
||||
@ -2141,6 +2253,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="rbac-role-id-456",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
mock_create_member.assert_called_once_with(
|
||||
@ -2191,6 +2304,7 @@ class TestRegisterService:
|
||||
language="en-US",
|
||||
role="editor",
|
||||
inviter=mock_inviter,
|
||||
session=mock_db_dependencies["db"].session,
|
||||
)
|
||||
|
||||
assert result == "legacy-token"
|
||||
@ -2300,7 +2414,9 @@ class TestRegisterService:
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, mock_account]
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
result = RegisterService.get_invitation_if_token_valid(
|
||||
"tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result is not None
|
||||
@ -2314,7 +2430,9 @@ class TestRegisterService:
|
||||
mock_redis_dependencies.get.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
result = RegisterService.get_invitation_if_token_valid(
|
||||
"tenant-456", "test@example.com", "token-123", session=MagicMock()
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result is None
|
||||
@ -2333,7 +2451,9 @@ class TestRegisterService:
|
||||
mock_db_dependencies["db"].session.scalar.return_value = None
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
result = RegisterService.get_invitation_if_token_valid(
|
||||
"tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result is None
|
||||
@ -2357,7 +2477,9 @@ class TestRegisterService:
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, None]
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
result = RegisterService.get_invitation_if_token_valid(
|
||||
"tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result is None
|
||||
@ -2384,7 +2506,9 @@ class TestRegisterService:
|
||||
mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, mock_account]
|
||||
|
||||
# Execute test
|
||||
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
|
||||
result = RegisterService.get_invitation_if_token_valid(
|
||||
"tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session
|
||||
)
|
||||
|
||||
# Verify results
|
||||
assert result is None
|
||||
@ -2395,22 +2519,28 @@ class TestRegisterService:
|
||||
with patch(
|
||||
"services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation
|
||||
) as mock_get:
|
||||
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
|
||||
result = RegisterService.get_invitation_with_case_fallback(
|
||||
"tenant-456", "User@Test.com", "token-123", session=MagicMock()
|
||||
)
|
||||
|
||||
assert result == invitation
|
||||
mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123")
|
||||
mock_get.assert_called_once_with(
|
||||
"tenant-456", "User@Test.com", "token-123", session=mock_get.call_args.kwargs["session"]
|
||||
)
|
||||
|
||||
def test_get_invitation_with_case_fallback_retries_with_lowercase(self):
|
||||
"""Fallback helper should retry with lowercase email when needed."""
|
||||
invitation = {"workspace_id": "tenant-456"}
|
||||
with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get:
|
||||
mock_get.side_effect = [None, invitation]
|
||||
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
|
||||
result = RegisterService.get_invitation_with_case_fallback(
|
||||
"tenant-456", "User@Test.com", "token-123", session=MagicMock()
|
||||
)
|
||||
|
||||
assert result == invitation
|
||||
assert mock_get.call_args_list == [
|
||||
(("tenant-456", "User@Test.com", "token-123"),),
|
||||
(("tenant-456", "user@test.com", "token-123"),),
|
||||
(("tenant-456", "User@Test.com", "token-123"), {"session": mock_get.call_args_list[0].kwargs["session"]}),
|
||||
(("tenant-456", "user@test.com", "token-123"), {"session": mock_get.call_args_list[1].kwargs["session"]}),
|
||||
]
|
||||
|
||||
# ==================== Helper Method Tests ====================
|
||||
|
||||
@ -546,7 +546,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFound):
|
||||
AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1")
|
||||
AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1", mock_db.session)
|
||||
|
||||
def test_update_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None:
|
||||
"""Test missing app raises NotFound in update path."""
|
||||
@ -562,7 +562,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFound):
|
||||
AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1")
|
||||
AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1", mock_db.session)
|
||||
|
||||
def test_update_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None:
|
||||
"""Test missing question raises ValueError."""
|
||||
@ -581,7 +581,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id)
|
||||
AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session)
|
||||
|
||||
def test_update_app_annotation_directly_should_update_annotation_and_index(self) -> None:
|
||||
"""Test update changes fields and triggers index update."""
|
||||
@ -602,7 +602,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
mock_db.session.get.return_value = annotation
|
||||
|
||||
# Act
|
||||
result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id)
|
||||
result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session)
|
||||
|
||||
# Assert
|
||||
assert result == annotation
|
||||
@ -640,7 +640,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
mock_db.session.scalars.return_value = scalars_result
|
||||
|
||||
# Act
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation.id)
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation.id, mock_db.session)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_any_call(annotation)
|
||||
@ -667,7 +667,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFound):
|
||||
AppAnnotationService.delete_app_annotation("app-1", "ann-1")
|
||||
AppAnnotationService.delete_app_annotation("app-1", "ann-1", mock_db.session)
|
||||
|
||||
def test_delete_app_annotation_should_raise_not_found_when_annotation_missing(self) -> None:
|
||||
"""Test delete raises NotFound when annotation is missing."""
|
||||
@ -684,7 +684,7 @@ class TestAppAnnotationServiceDirectManipulation:
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NotFound):
|
||||
AppAnnotationService.delete_app_annotation(app.id, "ann-1")
|
||||
AppAnnotationService.delete_app_annotation(app.id, "ann-1", mock_db.session)
|
||||
|
||||
def test_delete_app_annotations_in_batch_should_return_zero_when_none_found(self) -> None:
|
||||
"""Test batch delete returns zero when no annotations found."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user