chore: make AccountService.load_user use passed session (#37764)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-06-24 16:29:12 +09:00 committed by GitHub
parent f665bcac95
commit 2112115962
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
63 changed files with 1060 additions and 554 deletions

View File

@ -133,8 +133,9 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
password=new_password,
language=language,
create_workspace_required=False,
session=db.session,
)
TenantService.create_owner_tenant_if_not_exist(account, name)
TenantService.create_owner_tenant_if_not_exist(account, name, session=db.session)
click.echo(
click.style(

View File

@ -19,6 +19,7 @@ from controllers.console.wraps import (
rbac_permission_required,
setup_required,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
Annotation,
@ -388,7 +389,9 @@ class AnnotationUpdateDeleteApi(Resource):
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id))
annotation = AppAnnotationService.update_app_annotation_directly(
update_args, str(app_id), str(annotation_id), db.session
)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@ -398,7 +401,7 @@ class AnnotationUpdateDeleteApi(Resource):
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT)
@console_ns.response(204, "Annotation deleted successfully")
def delete(self, app_id: UUID, annotation_id: UUID):
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id), db.session)
return "", 204

View File

@ -17,6 +17,7 @@ from controllers.console.wraps import (
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import AccountWithRole
from libs.helper import build_avatar_url, dump_response, to_timestamp
@ -489,7 +490,7 @@ class WorkflowCommentMentionUsersApi(Resource):
current_tenant = current_user.current_tenant # need the tenant object here
if current_tenant is None:
raise ValueError("current tenant is required")
members = TenantService.get_tenant_members(current_tenant)
members = TenantService.get_tenant_members(current_tenant, session=db.session)
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = WorkflowCommentMentionUsersPayload(users=users)
return response.model_dump(mode="json"), 200

View File

@ -89,7 +89,9 @@ class ActivateCheckApi(Resource):
workspaceId = args.workspace_id
token = args.token
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
invitation = RegisterService.get_invitation_with_case_fallback(
workspaceId, args.email, token, session=db.session
)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
@ -137,7 +139,9 @@ class ActivateApi(Resource):
args = ActivatePayload.model_validate(console_ns.payload)
normalized_request_email = args.email.lower() if args.email else None
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
invitation = RegisterService.get_invitation_with_case_fallback(
args.workspace_id, args.email, args.token, session=db.session
)
if invitation is None:
raise AlreadyActivateError()
@ -184,6 +188,6 @@ class ActivateApi(Resource):
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
TenantService.switch_tenant(account, tenant.id)
TenantService.switch_tenant(account, tenant.id, session=db.session)
return {"result": "success"}

View File

@ -187,7 +187,7 @@ class EmailRegisterResetApi(Resource):
timezone=args.timezone,
language=args.language,
)
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
@ -206,6 +206,7 @@ class EmailRegisterResetApi(Resource):
password=password,
interface_language=get_valid_language(language),
timezone=timezone,
session=db.session,
)
except AccountRegisterError:
raise AccountInFreezeError()

View File

@ -198,10 +198,10 @@ class ForgotPasswordResetApi(Resource):
# Create workspace if needed
if (
not TenantService.get_join_tenants(account)
not TenantService.get_join_tenants(account, session=db.session)
and FeatureService.get_system_features().is_allow_create_workspace
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)

View File

@ -119,7 +119,9 @@ class LoginApi(Resource):
invite_token = args.invite_token
invitation_data: InvitationDetailDict | None = None
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
invitation_data = RegisterService.get_invitation_with_case_fallback(
None, request_email, invite_token, session=db.session
)
if invitation_data is None:
invite_token = None
@ -145,7 +147,7 @@ class LoginApi(Resource):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
tenants = TenantService.get_join_tenants(account, session=db.session)
if len(tenants) == 0:
system_features = FeatureService.get_system_features()
@ -157,7 +159,7 @@ class LoginApi(Resource):
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
@ -291,7 +293,7 @@ class EmailCodeLoginApi(Resource):
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
tenants = TenantService.get_join_tenants(account, session=db.session)
if not tenants:
workspaces = FeatureService.get_system_features().license.workspaces
if not workspaces.is_available():
@ -299,7 +301,7 @@ class EmailCodeLoginApi(Resource):
if not FeatureService.get_system_features().is_allow_create_workspace:
raise NotAllowedCreateWorkspace()
else:
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
TenantService.create_tenant_member(new_tenant, account, db.session, role="owner")
account.current_tenant = new_tenant
tenant_was_created.send(new_tenant)
@ -311,6 +313,7 @@ class EmailCodeLoginApi(Resource):
name=user_email,
interface_language=get_valid_language(language),
timezone=args.timezone,
session=db.session,
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
@ -319,7 +322,7 @@ class EmailCodeLoginApi(Resource):
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
token_pair = AccountService.login(account, session=db.session, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
@ -343,7 +346,7 @@ class RefreshTokenApi(Resource):
return {"result": "fail", "message": "No refresh token provided"}, 401
try:
new_token_pair = AccountService.refresh_token(refresh_token)
new_token_pair = AccountService.refresh_token(refresh_token, session=db.session)
# Create response with new cookies
response = make_response({"result": "success"})
@ -358,22 +361,22 @@ class RefreshTokenApi(Resource):
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
account = AccountService.get_user_through_email(email, session=db.session)
if account or email == email.lower():
return account
return AccountService.get_user_through_email(email.lower())
return AccountService.get_user_through_email(email.lower(), session=db.session)
def _authenticate_account_with_case_fallback(
original_email: str, normalized_email: str, password: str, invite_token: str | None
):
try:
return AccountService.authenticate(original_email, password, invite_token)
return AccountService.authenticate(original_email, password, invite_token, session=db.session)
except services.errors.account.AccountPasswordError:
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)
return AccountService.authenticate(normalized_email, password, invite_token, session=db.session)
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:

View File

@ -195,7 +195,7 @@ class OAuthCallback(Resource):
db.session.commit()
try:
TenantService.create_owner_tenant_if_not_exist(account)
TenantService.create_owner_tenant_if_not_exist(account, session=db.session)
except Unauthorized:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
@ -206,6 +206,7 @@ class OAuthCallback(Resource):
token_pair = AccountService.login(
account=account,
session=db.session,
ip_address=extract_remote_ip(request),
)
@ -240,12 +241,12 @@ def _generate_account(
oauth_new_user = False
if account:
tenants = TenantService.get_join_tenants(account)
tenants = TenantService.get_join_tenants(account, session=db.session)
if not tenants:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise WorkSpaceNotAllowedCreateError()
else:
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
TenantService.create_tenant_member(new_tenant, account, db.session, role="owner")
account.current_tenant = new_tenant
tenant_was_created.send(new_tenant)
@ -272,9 +273,10 @@ def _generate_account(
provider=provider,
language=interface_language,
timezone=timezone,
session=db.session,
)
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)
AccountService.link_account_integrate(provider, user_info.id, account, session=db.session)
return account, oauth_new_user

View File

@ -175,7 +175,7 @@ class InstalledAppsListApi(Resource):
if current_user.current_tenant is None:
raise ValueError("current_user.current_tenant must not be None")
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant, session=db.session)
installed_app_list: list[dict[str, Any]] = []
for installed_app, app_model in installed_apps:
installed_app_list.append(

View File

@ -50,7 +50,7 @@ def get_init_status() -> InitStatusResponse:
@only_edition_self_hosted
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
"""Validate initialization password."""
tenant_count = TenantService.get_tenant_count()
tenant_count = TenantService.get_tenant_count(session=db.session)
if tenant_count > 0:
raise AlreadySetupError()

View File

@ -79,7 +79,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
if get_setup_status():
raise AlreadySetupError()
tenant_count = TenantService.get_tenant_count()
tenant_count = TenantService.get_tenant_count(session=db.session)
if tenant_count > 0:
raise AlreadySetupError()
@ -94,6 +94,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
password=payload.password,
ip_address=extract_remote_ip(request),
language=payload.language,
session=db.session,
)
return SetupResponse(result="success")

View File

@ -4,6 +4,7 @@ from typing import cast
from flask import Request as FlaskRequest
from extensions.ext_database import db
from extensions.ext_socketio import sio
from libs.passport import PassportService
from libs.token import extract_access_token
@ -43,7 +44,7 @@ def socket_connect(sid, environ, auth):
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
user = AccountService.load_logged_in_account(account_id=user_id, session=db.session)
if not user:
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
return False

View File

@ -328,7 +328,7 @@ class AccountNameApi(Resource):
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
updated_account = AccountService.update_account(current_user, session=db.session, name=args.name)
return _serialize_account(updated_account)
@ -374,7 +374,7 @@ class AccountAvatarApi(Resource):
payload = console_ns.payload or {}
args = AccountAvatarPayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
updated_account = AccountService.update_account(current_user, session=db.session, avatar=args.avatar)
return _serialize_account(updated_account)
@ -391,7 +391,9 @@ class AccountInterfaceLanguageApi(Resource):
payload = console_ns.payload or {}
args = AccountInterfaceLanguagePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
updated_account = AccountService.update_account(
current_user, session=db.session, interface_language=args.interface_language
)
return _serialize_account(updated_account)
@ -408,7 +410,9 @@ class AccountInterfaceThemeApi(Resource):
payload = console_ns.payload or {}
args = AccountInterfaceThemePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
updated_account = AccountService.update_account(
current_user, session=db.session, interface_theme=args.interface_theme
)
return _serialize_account(updated_account)
@ -425,7 +429,7 @@ class AccountTimezoneApi(Resource):
payload = console_ns.payload or {}
args = AccountTimezonePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
updated_account = AccountService.update_account(current_user, session=db.session, timezone=args.timezone)
return _serialize_account(updated_account)
@ -443,7 +447,8 @@ class AccountPasswordApi(Resource):
args = AccountPasswordPayload.model_validate(payload)
try:
AccountService.update_account_password(current_user, args.password, args.new_password)
assert args.password is not None
AccountService.update_account_password(current_user, args.password, args.new_password, session=db.session)
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
@ -731,7 +736,7 @@ class ChangeEmailResetApi(Resource):
if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(normalized_new_email):
if not AccountService.check_email_unique(normalized_new_email, session=db.session):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@ -755,7 +760,9 @@ class ChangeEmailResetApi(Resource):
# legitimately verified token.
AccountService.revoke_change_email_token(args.token)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
updated_account = AccountService.update_account_email(
current_user, email=normalized_new_email, session=db.session
)
AccountService.send_change_email_completed_notify_email(
email=normalized_new_email,
@ -775,6 +782,6 @@ class CheckEmailUnique(Resource):
normalized_email = args.email.lower()
if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(normalized_email):
if not AccountService.check_email_unique(normalized_email, session=db.session):
raise EmailAlreadyInUseError()
return {"result": "success"}

View File

@ -186,7 +186,7 @@ class MemberListApi(Resource):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
members = TenantService.get_tenant_members(current_user.current_tenant, session=db.session)
if dify_config.RBAC_ENABLED:
member_ids = [member.id for member in members]
member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get(
@ -273,6 +273,7 @@ class MemberInviteEmailApi(Resource):
language=interface_language,
role=invitee_role,
inviter=inviter,
session=db.session,
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
@ -317,7 +318,9 @@ class MemberCancelInviteApi(Resource):
abort(404)
else:
try:
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
TenantService.remove_member_from_tenant(
current_user.current_tenant, member, current_user, session=db.session
)
except services.errors.account.CannotOperateSelfError as e:
return {"code": "cannot-operate-self", "message": str(e)}, 400
except services.errors.account.NoPermissionError as e:
@ -360,7 +363,9 @@ class MemberUpdateRoleApi(Resource):
try:
assert member is not None, "Member not found"
TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user)
TenantService.update_member_role(
current_user.current_tenant, member, new_role, current_user, session=db.session
)
except services.errors.account.CannotOperateSelfError as e:
return {"code": "cannot-operate-self", "message": str(e)}, 400
except services.errors.account.NoPermissionError as e:
@ -387,7 +392,7 @@ class DatasetOperatorMemberListApi(Resource):
def get(self, current_user: Account):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
members = TenantService.get_dataset_operator_members(current_user.current_tenant, session=db.session)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@ -413,7 +418,7 @@ class SendOwnerTransferEmailApi(Resource):
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
raise NotOwnerError()
if args.language is not None and args.language == "zh-Hans":
@ -448,7 +453,7 @@ class OwnerTransferCheckApi(Resource):
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
raise NotOwnerError()
user_email = current_user.email
@ -494,7 +499,7 @@ class OwnerTransfer(Resource):
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
raise NotOwnerError()
if current_user.id == str(member_id):
@ -516,12 +521,14 @@ class OwnerTransfer(Resource):
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_member(member, current_user.current_tenant):
if not TenantService.is_member(member, current_user.current_tenant, session=db.session):
raise MemberNotInTenantError()
try:
assert member is not None, "Member not found"
TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user)
TenantService.update_member_role(
current_user.current_tenant, member, "owner", current_user, session=db.session
)
AccountService.send_new_owner_transfer_notify_email(
account=member,

View File

@ -325,10 +325,10 @@ class TenantApi(Resource):
raise ValueError("No current tenant")
if tenant.status == TenantStatus.ARCHIVE:
tenants = TenantService.get_join_tenants(current_user)
tenants = TenantService.get_join_tenants(current_user, session=db.session)
# if there is any tenant, switch to the first one
if len(tenants) > 0:
TenantService.switch_tenant(current_user, tenants[0].id)
TenantService.switch_tenant(current_user, tenants[0].id, session=db.session)
tenant = tenants[0]
# else, raise Unauthorized
else:
@ -351,7 +351,7 @@ class SwitchWorkspaceApi(Resource):
# check if tenant_id is valid, 403 if not
try:
TenantService.switch_tenant(current_user, args.tenant_id)
TenantService.switch_tenant(current_user, args.tenant_id, session=db.session)
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")

View File

@ -47,7 +47,7 @@ class EnterpriseWorkspace(Resource):
if account is None:
return {"message": "owner account not found."}, 404
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session)
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
tenant_was_created.send(tenant)
@ -84,7 +84,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
def post(self):
args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session)
tenant_was_created.send(tenant)

View File

@ -128,7 +128,7 @@ class WorkspaceSwitchApi(Resource):
account = _load_account(auth_data.account_id)
try:
TenantService.switch_tenant(account, workspace_id)
TenantService.switch_tenant(account, workspace_id, session=db.session)
except AccountNotLinkTenantError:
raise NotFound("workspace not found")
@ -152,7 +152,7 @@ class WorkspaceMembersApi(Resource):
@accepts(query=MemberListQuery)
def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery):
tenant = _load_tenant(workspace_id)
members = TenantService.get_tenant_members(tenant)
members = TenantService.get_tenant_members(tenant, session=db.session)
total = len(members)
start = (query.page - 1) * query.limit
page_items = members[start : start + query.limit]
@ -184,6 +184,7 @@ class WorkspaceMembersApi(Resource):
language=None,
role=body.role,
inviter=inviter,
session=db.session,
)
except AccountAlreadyInTenantError as exc:
raise BadRequest(str(exc))
@ -232,7 +233,7 @@ class WorkspaceMemberApi(Resource):
raise NotFound("member not found")
try:
TenantService.remove_member_from_tenant(tenant, member, operator)
TenantService.remove_member_from_tenant(tenant, member, operator, session=db.session)
except CannotOperateSelfError as exc:
raise BadRequest(str(exc))
except NoPermissionError as exc:
@ -266,7 +267,7 @@ class WorkspaceMemberRoleApi(Resource):
raise NotFound("member not found")
try:
TenantService.update_member_role(tenant, member, body.role, operator)
TenantService.update_member_role(tenant, member, body.role, operator, session=db.session)
except CannotOperateSelfError as exc:
raise BadRequest(str(exc))
except NoPermissionError as exc:

View File

@ -10,6 +10,7 @@ from controllers.common.schema import query_params_from_model, register_response
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.annotation_fields import Annotation, AnnotationList
from fields.base import ResponseModel
@ -281,7 +282,9 @@ class AnnotationUpdateDeleteApi(Resource):
"""Update an existing annotation."""
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
annotation = AppAnnotationService.update_app_annotation_directly(
update_args, app_model.id, str(annotation_id), db.session
)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@ -310,5 +313,5 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def delete(self, app_model: App, annotation_id: UUID):
"""Delete an annotation."""
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id), db.session)
return "", 204

View File

@ -226,7 +226,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_multimodal_documents.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
account = AccountService.load_user(document.created_by, db.session)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)

View File

@ -291,7 +291,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
attachments.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
account = AccountService.load_user(document.created_by, db.session)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)

View File

@ -84,7 +84,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
if not user_id:
raise Unauthorized("Invalid Authorization token.")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, session=db.session)
return logged_in_account
elif request.blueprint == "openapi":
# Account-branch device-flow approval routes (approve / deny /
@ -103,7 +103,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
source = decoded.get("token_source")
if source or not user_id:
return None
return AccountService.load_logged_in_account(account_id=user_id)
return AccountService.load_logged_in_account(account_id=user_id, session=db.session)
elif request.blueprint == "web":
app_code = request.headers.get(HEADER_NAME_APP_CODE)
webapp_token = extract_webapp_passport(app_code, request) if app_code else None

View File

@ -16,7 +16,7 @@ def valid_password(password):
raise ValueError("Password must contain letters and numbers, and the length must be at least 8 characters.")
def hash_password(password_str, salt_byte):
def hash_password(password_str: str, salt_byte: bytes):
dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
return binascii.hexlify(dk)

View File

@ -1,3 +1,10 @@
"""Account, workspace, and invitation services.
Database access in this module is caller-scoped: methods that read or mutate ORM state accept an explicit
``session`` so controllers, tasks, and tests can control transaction lifetime and avoid hidden Flask-scoped session
usage inside service logic.
"""
import base64
import json
import logging
@ -235,7 +242,7 @@ class AccountService:
)
@staticmethod
def _refresh_account_last_active(account: Account) -> None:
def _refresh_account_last_active(account: Account, session: scoped_session | Session) -> None:
now = naive_utc_now()
refresh_before = now - ACCOUNT_LAST_ACTIVE_REFRESH_INTERVAL
@ -245,12 +252,12 @@ class AccountService:
if not AccountService._should_refresh_account_last_active(account.id):
return
db.session.execute(
session.execute(
update(Account)
.where(Account.id == account.id, Account.last_active_at < refresh_before)
.values(last_active_at=now, updated_at=func.current_timestamp())
)
db.session.commit()
session.commit()
@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str):
@ -295,20 +302,20 @@ class AccountService:
side-effects (current-tenant assignment, commit) are unwanted.
``session`` is injected by the caller so this service stays free
of the Flask-scoped ``db.session`` import.
of a Flask-scoped session import.
"""
return session.get(Account, account_id)
@staticmethod
def load_user(user_id: str) -> None | Account:
account = db.session.get(Account, user_id)
def load_user(user_id: str, session: scoped_session | Session) -> None | Account:
account = session.get(Account, user_id)
if not account:
return None
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
current_tenant = db.session.scalar(
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
.limit(1)
@ -316,7 +323,7 @@ class AccountService:
if current_tenant:
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = db.session.scalar(
available_ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
@ -328,13 +335,13 @@ class AccountService:
account.set_tenant_id(available_ta.tenant_id)
available_ta.current = True
available_ta.last_opened_at = naive_utc_now()
db.session.commit()
session.commit()
AccountService._refresh_account_last_active(account)
AccountService._refresh_account_last_active(account, session)
# NOTE: make sure account is accessible outside of a db session
# This ensures that it will work correctly after upgrading to Flask version 3.1.2
db.session.refresh(account)
db.session.close()
session.refresh(account)
session.close()
return account
@staticmethod
@ -352,10 +359,12 @@ class AccountService:
return token
@staticmethod
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
def authenticate(
email: str, password: str, invite_token: str | None = None, *, session: scoped_session | Session
) -> Account:
"""authenticate account with email and password"""
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
account = session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
raise AccountPasswordError("Invalid email or password.")
@ -378,12 +387,14 @@ class AccountService:
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()
session.commit()
return account
@staticmethod
def update_account_password(account, password, new_password):
def update_account_password(
account: Account, password: str, new_password: str, *, session: scoped_session | Session
):
"""update account password"""
if account.password and not compare_password(password, account.password, account.password_salt):
raise CurrentPasswordIncorrectError("Current password is incorrect.")
@ -400,8 +411,8 @@ class AccountService:
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.add(account)
db.session.commit()
session.add(account)
session.commit()
return account
@staticmethod
@ -413,6 +424,8 @@ class AccountService:
interface_theme: str = "light",
is_setup: bool | None = False,
timezone: str | None = None,
*,
session: scoped_session | Session,
) -> Account:
"""Create an account, preferring explicit user timezone over language-derived defaults."""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
@ -458,13 +471,19 @@ class AccountService:
timezone=resolved_timezone,
)
db.session.add(account)
db.session.commit()
session.add(account)
session.commit()
return account
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None
email: str,
name: str,
interface_language: str,
password: str | None = None,
timezone: str | None = None,
*,
session: scoped_session | Session,
) -> Account:
"""Create an account and owner workspace."""
account = AccountService.create_account(
@ -473,10 +492,11 @@ class AccountService:
interface_language=interface_language,
password=password,
timezone=timezone,
session=session,
)
try:
TenantService.create_owner_tenant_if_not_exist(account=account)
TenantService.create_owner_tenant_if_not_exist(account=account, session=session)
except Exception:
# Enterprise-only side-effect should run independently from personal workspace creation.
_try_join_enterprise_default_workspace(str(account.id))
@ -536,11 +556,11 @@ class AccountService:
delete_account_task.delay(account.id)
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account):
def link_account_integrate(provider: str, open_id: str, account: Account, *, session: scoped_session | Session):
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: AccountIntegrate | None = db.session.scalar(
account_integrate: AccountIntegrate | None = session.scalar(
select(AccountIntegrate)
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
.limit(1)
@ -556,62 +576,62 @@ class AccountService:
account_integrate = AccountIntegrate(
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
)
db.session.add(account_integrate)
session.add(account_integrate)
db.session.commit()
session.commit()
logger.info("Account %s linked %s account %s.", account.id, provider, open_id)
except Exception as e:
logger.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id)
raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod
def close_account(account: Account):
def close_account(account: Account, *, session: scoped_session | Session):
"""Close account"""
account.status = AccountStatus.CLOSED
db.session.commit()
session.commit()
@staticmethod
def update_account(account, **kwargs):
def update_account(account: Account, *, session: scoped_session | Session, **kwargs):
"""Update account fields"""
account = db.session.merge(account)
account = session.merge(account)
for field, value in kwargs.items():
if hasattr(account, field):
setattr(account, field, value)
else:
raise AttributeError(f"Invalid field: {field}")
db.session.commit()
session.commit()
return account
@staticmethod
def update_account_email(account: Account, email: str) -> Account:
def update_account_email(account: Account, email: str, session: scoped_session | Session) -> Account:
"""Update account email"""
account.email = email
account_integrate = db.session.scalar(
account_integrate = session.scalar(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
)
if account_integrate:
db.session.delete(account_integrate)
db.session.add(account)
db.session.commit()
session.delete(account_integrate)
session.add(account)
session.commit()
return account
@staticmethod
def update_login_info(account: Account, *, ip_address: str):
def update_login_info(account: Account, session: scoped_session | Session, *, ip_address: str):
"""Update last login time and ip"""
account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
db.session.add(account)
db.session.commit()
session.add(account)
session.commit()
@staticmethod
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
def login(account: Account, *, session: scoped_session | Session, ip_address: str | None = None) -> TokenPair:
if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address)
AccountService.update_login_info(account=account, session=session, ip_address=ip_address)
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
db.session.commit()
session.commit()
access_token = AccountService.get_account_jwt_token(account=account)
refresh_token = _generate_refresh_token()
@ -628,13 +648,13 @@ class AccountService:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
@staticmethod
def refresh_token(refresh_token: str) -> TokenPair:
def refresh_token(refresh_token: str, *, session: scoped_session | Session) -> TokenPair:
# Verify the refresh token
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
if not account_id:
raise ValueError("Invalid refresh token")
account = AccountService.load_user(account_id.decode("utf-8"))
account = AccountService.load_user(account_id.decode("utf-8"), session)
if not account:
raise ValueError("Invalid account")
@ -649,8 +669,8 @@ class AccountService:
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token)
@staticmethod
def load_logged_in_account(*, account_id: str):
return AccountService.load_user(account_id)
def load_logged_in_account(*, account_id: str, session: scoped_session | Session):
return AccountService.load_user(account_id, session)
@classmethod
def send_reset_password_email(
@ -1002,7 +1022,7 @@ class AccountService:
TokenManager.revoke_token(token, "email_code_login")
@classmethod
def get_user_through_email(cls, email: str):
def get_user_through_email(cls, email: str, *, session: scoped_session | Session):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
raise AccountRegisterError(
description=(
@ -1011,7 +1031,7 @@ class AccountService:
)
)
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
account = session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
return None
@ -1210,13 +1230,19 @@ class AccountService:
return False
@staticmethod
def check_email_unique(email: str) -> bool:
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
def check_email_unique(email: str, *, session: scoped_session | Session) -> bool:
return session.scalar(select(Account).where(Account.email == email).limit(1)) is None
class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
def create_tenant(
name: str,
is_setup: bool | None = False,
is_from_dashboard: bool | None = False,
*,
session: scoped_session | Session,
) -> Tenant:
"""Create tenant"""
if (
not FeatureService.get_system_features().is_allow_create_workspace
@ -1228,8 +1254,8 @@ class TenantService:
raise NotAllowedCreateWorkspace()
tenant = Tenant(name=name)
db.session.add(tenant)
db.session.commit()
session.add(tenant)
session.commit()
for category in TenantPluginAutoUpgradeStrategy.PluginCategory:
plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy(
@ -1241,11 +1267,11 @@ class TenantService:
exclude_plugins=[],
include_plugins=[],
)
db.session.add(plugin_upgrade_strategy)
db.session.commit()
session.add(plugin_upgrade_strategy)
session.commit()
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
session.commit()
from services.credit_pool_service import CreditPoolService
@ -1254,9 +1280,11 @@ class TenantService:
return tenant
@staticmethod
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
def create_owner_tenant_if_not_exist(
account: Account, name: str | None = None, is_setup: bool | None = False, *, session: scoped_session | Session
):
"""Check if user have a workspace or not"""
available_ta = db.session.scalar(
available_ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
@ -1275,10 +1303,10 @@ class TenantService:
raise WorkspacesLimitExceededError()
if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
tenant = TenantService.create_tenant(name=name, is_setup=is_setup, session=session)
else:
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup, session=session)
TenantService.create_tenant_member(tenant, account, session, role="owner")
if dify_config.RBAC_ENABLED:
owner_role_id = AccountService._resolve_legacy_role_id(str(tenant.id), account.id, TenantAccountRole.OWNER)
RBACService.MemberRoles.replace(
@ -1288,16 +1316,16 @@ class TenantService:
role_ids=[owner_role_id],
)
account.current_tenant = tenant
db.session.commit()
session.commit()
tenant_was_created.send(tenant)
@staticmethod
def create_tenant_member(
tenant: Tenant, account: Account, session: scoped_session, role: str = "normal"
tenant: Tenant, account: Account, session: scoped_session | Session, role: str = "normal"
) -> TenantAccountJoin:
"""Create tenant member"""
if role == TenantAccountRole.OWNER:
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER], session=session):
logger.error("Tenant %s has already an owner.", tenant.id)
raise Exception("Tenant already has an owner.")
@ -1318,10 +1346,10 @@ class TenantService:
return ta
@staticmethod
def get_join_tenants(account: Account) -> list[Tenant]:
def get_join_tenants(account: Account, *, session: scoped_session | Session) -> list[Tenant]:
"""Get account join tenants"""
return list(
db.session.scalars(
session.scalars(
select(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
@ -1340,7 +1368,7 @@ class TenantService:
membership + pick the default workspace.
``session`` is injected by the caller so this service stays free
of the Flask-scoped ``db.session`` import.
of a Flask-scoped session import.
No tenant-status filter: parity with the legacy controller query
(the openapi identity endpoint listed all joined tenants).
@ -1413,7 +1441,7 @@ class TenantService:
bearers (no account) collapse to the non-member path. Mirrors the
session-injection style of :meth:`account_belongs_to_tenant` rather
than :meth:`get_user_role`, which loads full ``Account``/``Tenant``
objects against the Flask-scoped ``db.session``.
objects against the Flask-scoped session.
"""
if not account_id:
return None
@ -1479,13 +1507,13 @@ class TenantService:
).first()
@staticmethod
def get_current_tenant_by_account(account: Account):
def get_current_tenant_by_account(account: Account, *, session: scoped_session | Session):
"""Get tenant by account and add the role"""
tenant = account.current_tenant
if not tenant:
raise TenantNotFoundError("Tenant not found.")
ta = db.session.scalar(
ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
@ -1497,14 +1525,14 @@ class TenantService:
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: str | None = None):
def switch_tenant(account: Account, tenant_id: str | None = None, *, session: scoped_session | Session):
"""Switch the current workspace for the account"""
# Ensure tenant_id is provided
if tenant_id is None:
raise ValueError("Tenant ID must be provided.")
tenant_account_join = db.session.scalar(
tenant_account_join = session.scalar(
select(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
.where(
@ -1518,7 +1546,7 @@ class TenantService:
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
db.session.execute(
session.execute(
update(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
.values(current=False)
@ -1527,10 +1555,10 @@ class TenantService:
tenant_account_join.last_opened_at = naive_utc_now()
# Set the current tenant for the account
account.set_tenant_id(tenant_account_join.tenant_id)
db.session.commit()
session.commit()
@staticmethod
def get_tenant_members(tenant: Tenant) -> list[Account]:
def get_tenant_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]:
"""Get tenant members"""
stmt = (
select(Account, TenantAccountJoin.role)
@ -1542,14 +1570,14 @@ class TenantService:
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in db.session.execute(stmt):
for account, role in session.execute(stmt):
account.role = role
updated_accounts.append(account)
return updated_accounts
@staticmethod
def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
def get_dataset_operator_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]:
"""Get dataset admin members"""
stmt = (
select(Account, TenantAccountJoin.role)
@ -1562,20 +1590,20 @@ class TenantService:
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in db.session.execute(stmt):
for account, role in session.execute(stmt):
account.role = role
updated_accounts.append(account)
return updated_accounts
@staticmethod
def has_roles(tenant: Tenant, roles: list[TenantAccountRole]) -> bool:
def has_roles(tenant: Tenant, roles: list[TenantAccountRole], *, session: scoped_session | Session) -> bool:
"""Check if user has any of the given roles for a tenant"""
if not all(isinstance(role, TenantAccountRole) for role in roles):
raise ValueError("all roles must be TenantAccountRole")
return (
db.session.scalar(
session.scalar(
select(TenantAccountJoin)
.where(
TenantAccountJoin.tenant_id == tenant.id,
@ -1587,9 +1615,11 @@ class TenantService:
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
def get_user_role(
account: Account, tenant: Tenant, *, session: scoped_session | Session
) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant"""
join = db.session.scalar(
join = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
@ -1597,12 +1627,14 @@ class TenantService:
return TenantAccountRole(join.role) if join else None
@staticmethod
def get_tenant_count() -> int:
def get_tenant_count(*, session: scoped_session | Session) -> int:
"""Get tenant count"""
return cast(int, db.session.scalar(select(func.count(Tenant.id))))
return cast(int, session.scalar(select(func.count(Tenant.id))))
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
def check_member_permission(
tenant: Tenant, operator: Account, member: Account | None, action: str, *, session: scoped_session | Session
):
"""Check member permission"""
if action not in {"add", "remove", "update"}:
raise InvalidActionError("Invalid action.")
@ -1636,7 +1668,7 @@ class TenantService:
"update": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
}
ta_operator = db.session.scalar(
ta_operator = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
.limit(1)
@ -1646,7 +1678,7 @@ class TenantService:
raise NoPermissionError(f"No permission to {action} member.")
if action == "remove" and ta_operator.role == TenantAccountRole.ADMIN and member:
ta_member = db.session.scalar(
ta_member = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
.limit(1)
@ -1655,7 +1687,9 @@ class TenantService:
raise NoPermissionError(f"No permission to {action} member.")
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
def remove_member_from_tenant(
tenant: Tenant, account: Account, operator: Account, *, session: scoped_session | Session
):
"""Remove member from tenant.
Apps and datasets maintained by the removed member are reassigned to
@ -1667,9 +1701,9 @@ class TenantService:
if operator.id == account.id:
raise CannotOperateSelfError("Cannot operate self.")
TenantService.check_member_permission(tenant, operator, account, "remove")
TenantService.check_member_permission(tenant, operator, account, "remove", session=session)
ta = db.session.scalar(
ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
@ -1686,7 +1720,7 @@ class TenantService:
if dify_config.RBAC_ENABLED:
owner_id = AccountService.get_rbac_workspace_owner_account_id(str(tenant.id), str(operator.id))
else:
owner_id = db.session.scalar(
owner_id = session.scalar(
select(TenantAccountJoin.account_id)
.where(
TenantAccountJoin.tenant_id == tenant.id,
@ -1697,7 +1731,7 @@ class TenantService:
if owner_id is None:
raise ValueError(f"Workspace owner not found for tenant {tenant.id}.")
db.session.execute(
session.execute(
update(App)
.where(
App.tenant_id == tenant.id,
@ -1705,7 +1739,7 @@ class TenantService:
)
.values(maintainer=owner_id)
)
db.session.execute(
session.execute(
update(Dataset)
.where(
Dataset.tenant_id == tenant.id,
@ -1713,23 +1747,23 @@ class TenantService:
)
.values(maintainer=owner_id)
)
db.session.delete(ta)
session.delete(ta)
# Clean up orphaned pending accounts (invited but never activated)
should_delete_account = False
if account.status == AccountStatus.PENDING:
# autoflush flushes ta deletion before this query, so 0 means no remaining joins
remaining_joins = (
db.session.scalar(
session.scalar(
select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
)
or 0
)
if remaining_joins == 0:
db.session.delete(account)
session.delete(account)
should_delete_account = True
db.session.commit()
session.commit()
if should_delete_account:
logger.info(
@ -1755,12 +1789,14 @@ class TenantService:
)
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
def update_member_role(
tenant: Tenant, member: Account, new_role: str, operator: Account, *, session: scoped_session | Session
):
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
TenantService.check_member_permission(tenant, operator, member, "update", session=session)
new_tenant_role = TenantAccountRole(new_role)
target_member_join = db.session.scalar(
target_member_join = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
.limit(1)
@ -1769,7 +1805,7 @@ class TenantService:
if not target_member_join:
raise MemberNotInTenantError("Member not in tenant.")
operator_role = TenantService.get_user_role(operator, tenant)
operator_role = TenantService.get_user_role(operator, tenant, session=session)
target_role = TenantAccountRole(target_member_join.role)
if operator_role == TenantAccountRole.ADMIN and (TenantAccountRole.OWNER in {target_role, new_tenant_role}):
raise NoPermissionError("No permission to update member.")
@ -1779,7 +1815,7 @@ class TenantService:
if new_role == "owner":
# Find the current owner and change their role to 'admin'
current_owner_join = db.session.scalar(
current_owner_join = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
@ -1815,7 +1851,7 @@ class TenantService:
)
else:
target_member_join.role = new_tenant_role
db.session.commit()
session.commit()
@staticmethod
def get_custom_config(tenant_id: str):
@ -1824,13 +1860,13 @@ class TenantService:
return tenant.custom_config_dict
@staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool:
return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER
def is_owner(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool:
return TenantService.get_user_role(account, tenant, session=session) == TenantAccountRole.OWNER
@staticmethod
def is_member(account: Account, tenant: Tenant) -> bool:
def is_member(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool:
"""Check if the account is a member of the tenant"""
return TenantService.get_user_role(account, tenant) is not None
return TenantService.get_user_role(account, tenant, session=session) is not None
class RegisterService:
@ -1839,7 +1875,16 @@ class RegisterService:
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None):
def setup(
cls,
email: str,
name: str,
password: str,
ip_address: str,
language: str | None,
*,
session: scoped_session | Session,
):
"""
Setup dify
@ -1856,22 +1901,23 @@ class RegisterService:
interface_language=get_valid_language(language),
password=password,
is_setup=True,
session=session,
)
account.last_login_ip = ip_address
account.initialized_at = naive_utc_now()
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True, session=session)
dify_setup = DifySetup(version=dify_config.project.version)
db.session.add(dify_setup)
db.session.commit()
session.add(dify_setup)
session.commit()
except Exception as e:
db.session.execute(delete(DifySetup))
db.session.execute(delete(TenantAccountJoin))
db.session.execute(delete(Account))
db.session.execute(delete(Tenant))
db.session.commit()
session.execute(delete(DifySetup))
session.execute(delete(TenantAccountJoin))
session.execute(delete(Account))
session.execute(delete(Tenant))
session.commit()
logger.exception("Setup account failed, email: %s, name: %s", email, name)
raise ValueError(f"Setup failed: {e}")
@ -1889,9 +1935,11 @@ class RegisterService:
is_setup: bool | None = False,
create_workspace_required: bool | None = True,
timezone: str | None = None,
*,
session: scoped_session | Session,
) -> Account:
"""Register account"""
db.session.begin_nested()
session.begin_nested()
try:
interface_language = get_valid_language(language)
account = AccountService.create_account(
@ -1901,12 +1949,13 @@ class RegisterService:
password=password,
is_setup=is_setup,
timezone=timezone,
session=session,
)
account.status = status or AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
AccountService.link_account_integrate(provider, open_id, account, session=session)
if (
FeatureService.get_system_features().is_allow_create_workspace
@ -1914,27 +1963,27 @@ class RegisterService:
and FeatureService.get_system_features().license.workspaces.is_available()
):
try:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=session)
TenantService.create_tenant_member(tenant, account, session, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
except Exception:
_try_join_enterprise_default_workspace(str(account.id))
raise
db.session.commit()
session.commit()
_try_join_enterprise_default_workspace(str(account.id))
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
session.rollback()
logger.exception("Register failed")
raise AccountRegisterError("Workspace is not allowed to create.")
except AccountRegisterError as are:
db.session.rollback()
session.rollback()
logger.exception("Register failed")
raise are
except Exception as e:
db.session.rollback()
session.rollback()
logger.exception("Register failed")
raise AccountRegisterError(f"Registration failed: {e}") from e
@ -1942,7 +1991,14 @@ class RegisterService:
@classmethod
def invite_new_member(
cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None
cls,
tenant: Tenant,
email: str,
language: str | None,
role: str = "normal",
inviter: Account | None = None,
*,
session: scoped_session | Session,
) -> str:
if not inviter:
raise ValueError("Inviter is required")
@ -1960,7 +2016,7 @@ class RegisterService:
requires_setup = False
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
TenantService.check_member_permission(tenant, inviter, None, "add", session=session)
name = normalized_email.split("@")[0]
account = cls.register(
@ -1969,13 +2025,14 @@ class RegisterService:
language=language,
status=AccountStatus.PENDING,
is_setup=True,
session=session,
)
TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role)
TenantService.switch_tenant(account, tenant.id)
TenantService.create_tenant_member(tenant, account, session, tenant_join_role)
TenantService.switch_tenant(account, tenant.id, session=session)
requires_setup = True
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = db.session.scalar(
TenantService.check_member_permission(tenant, inviter, account, "add", session=session)
ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
@ -1983,7 +2040,7 @@ class RegisterService:
requires_setup = account.status == AccountStatus.PENDING
if not ta and (account.status == AccountStatus.PENDING or dify_config.RBAC_ENABLED):
TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role)
TenantService.create_tenant_member(tenant, account, session, tenant_join_role)
# Support resend invitation email when the account is pending status
if account.status != AccountStatus.PENDING:
@ -2052,20 +2109,20 @@ class RegisterService:
@classmethod
def get_invitation_if_token_valid(
cls, workspace_id: str | None, email: str | None, token: str
cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session
) -> InvitationDetailDict | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
tenant = db.session.scalar(
tenant = session.scalar(
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
)
if not tenant:
return None
account = db.session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
account = session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
if not account:
return None
@ -2105,13 +2162,13 @@ class RegisterService:
@classmethod
def get_invitation_with_case_fallback(
cls, workspace_id: str | None, email: str | None, token: str
cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session
) -> InvitationDetailDict | None:
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token, session=session)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token, session=session)
def _generate_refresh_token(length: int = 64):

View File

@ -4,6 +4,7 @@ from typing import TypedDict
import pandas as pd
from sqlalchemy import delete, or_, select, update
from sqlalchemy.orm import scoped_session
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -300,17 +301,19 @@ class AppAnnotationService:
return annotation
@classmethod
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
def update_app_annotation_directly(
cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str, session: scoped_session
):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
app = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.get(MessageAnnotation, annotation_id)
annotation = session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
@ -326,9 +329,9 @@ class AppAnnotationService:
annotation.content = answer
annotation.question = question
db.session.commit()
session.commit()
# if annotation reply is enabled , add annotation to index
app_annotation_setting = db.session.scalar(
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
@ -344,33 +347,33 @@ class AppAnnotationService:
return annotation
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
def delete_app_annotation(cls, app_id: str, annotation_id: str, session: scoped_session):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
app = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.get(MessageAnnotation, annotation_id)
annotation = session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
db.session.delete(annotation)
session.delete(annotation)
annotation_hit_histories = db.session.scalars(
annotation_hit_histories = session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
).all()
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)
session.delete(annotation_hit_history)
db.session.commit()
session.commit()
# if annotation reply is enabled , delete annotation index
app_annotation_setting = db.session.scalar(
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)

View File

@ -91,4 +91,4 @@ class OAuthServerService:
user_id_str = user_account_id.decode("utf-8")
return AccountService.load_user(user_id_str)
return AccountService.load_user(user_id_str, db.session)

View File

@ -36,7 +36,9 @@ class WorkspaceService:
feature = FeatureService.get_features(tenant.id, exclude_vector_space=True)
can_replace_logo = feature.can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
if can_replace_logo and TenantService.has_roles(
tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], session=db.session
):
base_url = dify_config.FILES_URL
replace_webapp_logo = (
f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"

View File

@ -84,6 +84,7 @@ def setup_account(request) -> Generator[Account, None, None]:
password=secrets.token_hex(16),
ip_address="localhost",
language="en-US",
session=db.session,
)
with _CACHED_APP.test_request_context():

View File

@ -548,6 +548,7 @@ class TestAccountGeneration:
provider="github",
language="en-US",
timezone=None,
session=ANY,
)
else:
mock_register_service.register.assert_not_called()
@ -581,6 +582,7 @@ class TestAccountGeneration:
provider="github",
language="en-US",
timezone=None,
session=ANY,
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@ -612,6 +614,7 @@ class TestAccountGeneration:
provider="github",
language="zh-Hans",
timezone="Asia/Shanghai",
session=ANY,
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@ -643,6 +646,7 @@ class TestAccountGeneration:
provider="github",
language="zh-Hans",
timezone=None,
session=ANY,
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@ -673,7 +677,7 @@ class TestAccountGeneration:
assert result == mock_account
assert oauth_new_user is False
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace", session=ANY)
mock_tenant_service.create_tenant_member.assert_called_once_with(
mock_new_tenant, mock_account, ANY, role="owner"
)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import uuid
from collections.abc import Callable, Iterator
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Literal
from unittest.mock import patch
@ -12,7 +12,6 @@ from flask import Flask
from sqlalchemy.orm import Session
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
from models import Account, Tenant
from services.account_service import AccountService, TenantService
@ -46,20 +45,25 @@ def make_account(db_session_with_containers: Session) -> Callable[..., Account]:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
if with_owner_tenant:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(
account, name=fake.company(), session=db_session_with_containers
)
return account
return _make
def add_tenant_for_account(account: Account, *, role: str = "normal", name: str = "Second WS") -> Tenant:
def add_tenant_for_account(
account: Account, *, session: Session, role: str = "normal", name: str = "Second WS"
) -> Tenant:
"""Create an additional tenant and join ``account`` to it (real service calls)."""
with patch("services.account_service.FeatureService") as mock_feature_service:
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
tenant = TenantService.create_tenant(name=name)
TenantService.create_tenant_member(tenant, account, db.session, role=role)
tenant = TenantService.create_tenant(name=name, session=session)
TenantService.create_tenant_member(tenant, account, session, role=role)
return tenant
@ -93,7 +97,7 @@ def account_auth_context(
*,
token_id: uuid.UUID,
client_id: str = "integration-cli",
) -> Iterator[AuthContext]:
) -> Generator[AuthContext]:
"""Publish an account ``AuthContext`` for handlers that read ``get_auth_ctx()``.
The auth pipeline normally sets this ContextVar; the integration suite

View File

@ -4,6 +4,7 @@ from collections.abc import Callable
from inspect import unwrap
from flask import Flask
from sqlalchemy.orm import Session
from controllers.openapi.account import AccountApi
from models import Account
@ -34,11 +35,13 @@ class TestAccountInfo:
# the only workspace the account belongs to.
assert result.default_workspace_id == owner_tenant.id
def test_lists_all_joined_workspaces(self, app: Flask, make_account: Callable[..., Account]) -> None:
def test_lists_all_joined_workspaces(
self, app: Flask, db_session_with_containers: Session, make_account: Callable[..., Account]
) -> None:
account = make_account()
owner_tenant = account.current_tenant
assert owner_tenant is not None
second = add_tenant_for_account(account, role="normal", name="Second WS")
second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS")
api = AccountApi()
with app.test_request_context("/openapi/v1/account"):

View File

@ -81,8 +81,9 @@ def _app_and_account(db_session: Session, *, mode: str = "chat") -> tuple[App, A
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session)
tenant = account.current_tenant
assert tenant is not None
app_args = CreateAppParams(

View File

@ -41,7 +41,7 @@ class TestWorkspacesList:
account = make_account()
owner_tenant = account.current_tenant
assert owner_tenant is not None
second = add_tenant_for_account(account, role="normal", name="Second WS")
second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS")
api = WorkspacesApi()
with app.test_request_context("/openapi/v1/workspaces"):
@ -90,7 +90,9 @@ class TestWorkspaceSwitch:
account = make_account()
owner_tenant = account.current_tenant
assert owner_tenant is not None
target = add_tenant_for_account(account, role="normal", name="Switch Target")
target = add_tenant_for_account(
account, session=db_session_with_containers, role="normal", name="Switch Target"
)
api = WorkspaceSwitchApi()
with app.test_request_context(f"/openapi/v1/workspaces/{target.id}/switch", method="POST"):

View File

@ -27,8 +27,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -88,8 +89,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -141,8 +143,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -194,8 +197,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -257,8 +261,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -291,8 +296,11 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
account1, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
tenant1 = account1.current_tenant
account2 = AccountService.create_account(
@ -300,8 +308,11 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
account2, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
tenant2 = account2.current_tenant
# Create dataset for tenant1
@ -367,8 +378,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Don't create any datasets
@ -391,8 +403,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create multiple datasets
@ -452,8 +465,9 @@ class TestKnowledgeRetrievalIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -520,8 +534,9 @@ class TestKnowledgeRetrievalIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset but no documents
@ -568,8 +583,9 @@ class TestKnowledgeRetrievalIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(

View File

@ -114,8 +114,9 @@ class TestAgentService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data

View File

@ -81,8 +81,9 @@ class TestAnnotationService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Setup app creation arguments
@ -280,7 +281,9 @@ class TestAnnotationService:
"question": fake.sentence(),
"answer": fake.text(max_nb_chars=200),
}
updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id)
updated_annotation = AppAnnotationService.update_app_annotation_directly(
updated_args, app.id, annotation.id, db_session_with_containers
)
# Verify annotation was updated correctly
assert updated_annotation.id == annotation.id
@ -567,7 +570,7 @@ class TestAnnotationService:
annotation_id = annotation.id
# Delete the annotation
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers)
# Verify annotation was deleted
@ -595,7 +598,7 @@ class TestAnnotationService:
# Try to delete annotation with non-existent app
with pytest.raises(NotFound, match="App not found"):
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id)
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id, db_session_with_containers)
def test_delete_app_annotation_annotation_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -609,7 +612,7 @@ class TestAnnotationService:
# Try to delete non-existent annotation
with pytest.raises(NotFound, match="Annotation not found"):
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id)
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id, db_session_with_containers)
def test_enable_app_annotation_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -1225,7 +1228,9 @@ class TestAnnotationService:
"question": fake.sentence(),
"answer": fake.text(max_nb_chars=200),
}
updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id)
updated_annotation = AppAnnotationService.update_app_annotation_directly(
updated_args, app.id, annotation.id, db_session_with_containers
)
# Verify annotation was updated correctly
assert updated_annotation.id == annotation.id
@ -1295,7 +1300,7 @@ class TestAnnotationService:
mock_external_service_dependencies["delete_task"].delay.reset_mock()
# Delete the annotation
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers)
# Verify annotation was deleted
deleted_annotation = (

View File

@ -57,8 +57,9 @@ class TestAPIBasedExtensionService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
return account, tenant

View File

@ -145,8 +145,11 @@ class TestAppDslService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
account, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
app_args = CreateAppParams(
name=fake.company(),

View File

@ -166,8 +166,9 @@ class TestAppGenerateService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppService, CreateAppParams

View File

@ -62,8 +62,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Setup app creation arguments
@ -119,8 +120,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -162,8 +164,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -210,8 +213,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -261,8 +265,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppListParams, AppService, CreateAppParams
@ -344,8 +349,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from models import AppStar
@ -404,8 +410,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppService, CreateAppParams, StarredAppListParams
@ -500,8 +507,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -566,14 +574,18 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
first_account, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(first_account, name=fake.company())
tenant = first_account.current_tenant
second_account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
from services.app_service import AppListParams, AppService, CreateAppParams
@ -623,8 +635,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -685,8 +698,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -755,8 +769,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppService, CreateAppParams
@ -807,8 +822,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppService, CreateAppParams
@ -857,8 +873,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -910,8 +927,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -971,8 +989,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1030,8 +1049,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1089,8 +1109,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1139,8 +1160,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1190,8 +1212,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1249,8 +1272,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1287,8 +1311,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1326,8 +1351,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1375,8 +1401,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
# Import here to avoid circular dependency
from services.app_service import CreateAppParams
@ -1411,8 +1438,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency

View File

@ -98,8 +98,9 @@ class TestMessageService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Setup app creation arguments
@ -648,8 +649,11 @@ class TestMessageService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
other_account, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company())
# Test getting message with different user
with pytest.raises(MessageNotExistsError):

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import uuid
from typing import cast
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
from uuid import uuid4
import pytest
@ -172,4 +172,4 @@ class TestOAuthServerServiceTokenOperations:
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
assert result is expected_user
mock_load.assert_called_once_with("user-88")
mock_load.assert_called_once_with("user-88", ANY)

View File

@ -51,8 +51,9 @@ class TestOpsService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
app_service = AppService()
app = app_service.create_app(

View File

@ -68,8 +68,9 @@ class TestSavedMessageService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data

View File

@ -55,6 +55,7 @@ class TestTriggerProviderService:
def _create_test_account_and_tenant(
self,
mock_external_service_dependencies: MockExternalServiceDependencies,
db_session_with_containers: Session,
) -> tuple[Account, Tenant]:
"""
Helper method to create a test account and tenant for testing.
@ -83,8 +84,9 @@ class TestTriggerProviderService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
assert tenant is not None
@ -164,7 +166,9 @@ class TestTriggerProviderService:
- Database state is correctly updated
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -262,7 +266,9 @@ class TestTriggerProviderService:
- Merged credentials contain only new values
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -320,7 +326,9 @@ class TestTriggerProviderService:
- Original credentials are preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -376,7 +384,9 @@ class TestTriggerProviderService:
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -434,7 +444,9 @@ class TestTriggerProviderService:
- Original subscription state is preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -474,9 +486,8 @@ class TestTriggerProviderService:
assert subscription.name == original_name
assert subscription.parameters == original_parameters
@pytest.mark.usefixtures("db_session_with_containers")
def test_rebuild_trigger_subscription_subscription_not_found(
self, mock_external_service_dependencies: MockExternalServiceDependencies
self, mock_external_service_dependencies: MockExternalServiceDependencies, db_session_with_containers: Session
) -> None:
"""
Test error when subscription is not found.
@ -485,7 +496,9 @@ class TestTriggerProviderService:
- Proper error is raised when subscription doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
fake_subscription_id = fake.uuid4()
@ -509,7 +522,9 @@ class TestTriggerProviderService:
- Error is raised when new name conflicts with existing subscription
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY

View File

@ -72,8 +72,9 @@ class TestWebConversationService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data

View File

@ -63,8 +63,9 @@ class TestWebhookService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
assert tenant is not None

View File

@ -78,8 +78,9 @@ class TestWorkflowAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -126,8 +127,9 @@ class TestWorkflowAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
return tenant, account

View File

@ -74,8 +74,9 @@ class TestWorkflowRunService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data
@ -530,8 +531,9 @@ class TestWorkflowRunService:
name="Test User",
password="password123",
interface_language="en-US",
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
tenant = account.current_tenant
# Create app
@ -581,8 +583,9 @@ class TestWorkflowRunService:
name="Test User",
password="password123",
interface_language="en-US",
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
tenant = account.current_tenant
# Create app
@ -632,8 +635,9 @@ class TestWorkflowRunService:
name="Test User",
password="password123",
interface_language="en-US",
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
tenant = account.current_tenant
# Create app

View File

@ -89,8 +89,9 @@ class TestWorkflowToolManageService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data

View File

@ -90,8 +90,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -211,8 +212,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -255,8 +257,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Test different index types
@ -342,8 +345,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -424,8 +428,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -525,8 +530,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -625,8 +631,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -717,8 +724,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -820,8 +828,11 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
account, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
accounts.append(account)
tenants.append(tenant)
@ -926,8 +937,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -1031,8 +1043,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset with built-in fields enabled

View File

@ -67,8 +67,9 @@ class TestDealDatasetVectorIndexTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
assert tenant is not None
return account, tenant

View File

@ -8,7 +8,7 @@ This module tests the account activation mechanism including:
- Initial login after activation
"""
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -41,7 +41,7 @@ class TestActivateCheckApi:
}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
def test_check_valid_invitation_token(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock):
"""
Test checking valid invitation token.
@ -67,7 +67,9 @@ class TestActivateCheckApi:
assert response["data"]["email"] == "invitee@example.com"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_valid_invitation_token_includes_account_status(self, mock_get_invitation, app, mock_invitation):
def test_check_valid_invitation_token_includes_account_status(
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
):
mock_account = MagicMock()
mock_account.status = AccountStatus.ACTIVE
mock_invitation["account"] = mock_account
@ -103,7 +105,9 @@ class TestActivateCheckApi:
assert response["is_valid"] is False
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
def test_check_token_without_workspace_id(
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
):
"""
Test checking token without workspace ID.
@ -121,10 +125,10 @@ class TestActivateCheckApi:
# Assert
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token", session=ANY)
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
def test_check_token_without_email(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation):
"""
Test checking token without email parameter.
@ -142,10 +146,12 @@ class TestActivateCheckApi:
# Assert
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token", session=ANY)
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
def test_check_token_normalizes_email_to_lowercase(
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
):
"""Ensure token validation uses lowercase emails."""
mock_get_invitation.return_value = mock_invitation
@ -156,7 +162,7 @@ class TestActivateCheckApi:
response = api.get()
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY)
class TestActivateApi:
@ -554,7 +560,7 @@ class TestActivateApi:
response = api.post()
assert response["result"] == "success"
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY)
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.TenantService.create_tenant_member")
@ -593,7 +599,7 @@ class TestActivateApi:
mock_create_tenant_member.assert_called_once_with(
mock_invitation["tenant"], mock_account, mock_db.session, role=TenantAccountRole.ADMIN
)
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id)
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY)
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.TenantService.create_tenant_member")
@ -628,5 +634,5 @@ class TestActivateApi:
assert response["result"] == "success"
mock_create_tenant_member.assert_not_called()
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id)
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY)
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from pydantic import ValidationError
@ -25,6 +25,7 @@ def test_create_new_account_uses_requested_language(mock_create_account):
password="ValidPass123!",
interface_language="zh-Hans",
timezone="Asia/Shanghai",
session=ANY,
)

View File

@ -9,7 +9,7 @@ This module tests the email code login mechanism including:
"""
import base64
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -368,6 +368,7 @@ class TestEmailCodeLoginApi:
name="newuser@example.com",
interface_language="en-US",
timezone="Asia/Shanghai",
session=ANY,
)
@patch("controllers.console.wraps.db")

View File

@ -9,7 +9,7 @@ This module tests the core authentication endpoints including:
"""
import base64
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import ANY, MagicMock, Mock, patch
import pytest
from flask import Flask
@ -129,7 +129,7 @@ class TestLoginApi:
response = login_api.post()
# Assert
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None, session=ANY)
mock_login.assert_called_once()
mock_reset_rate_limit.assert_called_once_with("test@example.com")
assert response.json["result"] == "success"
@ -184,7 +184,7 @@ class TestLoginApi:
response = login_api.post()
# Assert
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token")
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token", session=ANY)
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@ -407,13 +407,13 @@ class TestLoginApi:
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_login_retries_with_lowercase_email(
self,
mock_reset_rate_limit,
mock_login_service,
mock_get_tenants,
mock_add_rate_limit,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_reset_rate_limit: MagicMock,
mock_login_service: MagicMock,
mock_get_tenants: MagicMock,
mock_add_rate_limit: MagicMock,
mock_authenticate: MagicMock,
mock_get_invitation: MagicMock,
mock_is_rate_limit: MagicMock,
mock_db,
app: Flask,
mock_account,
@ -435,8 +435,8 @@ class TestLoginApi:
assert response.json["result"] == "success"
assert mock_authenticate.call_args_list == [
(("Upper@Example.com", "ValidPass123!", None), {}),
(("upper@example.com", "ValidPass123!", None), {}),
(("Upper@Example.com", "ValidPass123!", None), {"session": ANY}),
(("upper@example.com", "ValidPass123!", None), {"session": ANY}),
]
mock_add_rate_limit.assert_not_called()
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
@ -447,10 +447,10 @@ class TestLoginApi:
@patch("controllers.console.auth.login._get_account_with_case_fallback")
def test_email_code_login_logs_banned_account(
self,
mock_get_account,
mock_revoke_token,
mock_get_token_data,
mock_db,
mock_get_account: MagicMock,
mock_revoke_token: MagicMock,
mock_get_token_data: MagicMock,
mock_db: MagicMock,
app: Flask,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
@ -491,7 +491,9 @@ class TestLogoutApi:
@patch("controllers.console.auth.login.AccountService.logout")
@patch("controllers.console.auth.login.flask_login.logout_user")
def test_successful_logout(self, mock_logout_user, mock_service_logout, app: Flask, mock_account):
def test_successful_logout(
self, mock_logout_user: MagicMock, mock_service_logout: MagicMock, app: Flask, mock_account
):
"""
Test successful logout flow.

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -66,8 +66,9 @@ def test_generate_account_registers_with_browser_timezone(
provider="github",
language="zh-Hans",
timezone="Asia/Shanghai",
session=ANY,
)
mock_link_account.assert_called_once_with("github", "github-123", account)
mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY)
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
@ -97,8 +98,9 @@ def test_generate_account_prefers_state_language_over_accept_language(
provider="github",
language="zh-Hans",
timezone=None,
session=ANY,
)
mock_link_account.assert_called_once_with("github", "github-123", account)
mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY)
@patch("controllers.console.auth.oauth.dify_config")

View File

@ -8,7 +8,7 @@ This module tests the token refresh mechanism including:
- Error handling for invalid tokens
"""
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -70,7 +70,7 @@ class TestRefreshTokenApi:
# Assert
mock_extract_token.assert_called_once()
mock_refresh_token.assert_called_once_with("valid_refresh_token")
mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY)
assert response.json["result"] == "success"
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@ -191,7 +191,7 @@ class TestRefreshTokenApi:
# Assert
assert response.json["result"] == "success"
# Verify new token pair was generated
mock_refresh_token.assert_called_once_with("valid_refresh_token")
mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY)
# In real implementation, cookies would be set with new values
assert mock_token_pair.access_token == "new_access_token"
assert mock_token_pair.refresh_token == "new_refresh_token"

View File

@ -38,7 +38,7 @@ def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 1)
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
@ -48,7 +48,7 @@ def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.Mo
def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
@ -60,7 +60,7 @@ def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.M
def test_validate_init_password_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"

View File

@ -1,6 +1,6 @@
import inspect
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -394,12 +394,12 @@ class TestChangeEmailReset:
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_new_email_before_update(
self,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_is_freeze: MagicMock,
mock_check_unique: MagicMock,
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_update_account: MagicMock,
mock_send_notify: MagicMock,
app: Flask,
):
current_user = _build_account("old@example.com", "acc3")
@ -424,9 +424,9 @@ class TestChangeEmailReset:
method(api, current_user)
mock_is_freeze.assert_called_once_with("new@example.com")
mock_check_unique.assert_called_once_with("new@example.com")
mock_check_unique.assert_called_once_with("new@example.com", session=ANY)
mock_revoke_token.assert_called_once_with("token-123")
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
mock_update_account.assert_called_once_with(current_user, email="new@example.com", session=ANY)
mock_send_notify.assert_called_once_with(email="new@example.com")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@ -437,12 +437,12 @@ class TestChangeEmailReset:
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_reject_reset_when_token_phase_is_not_new_verified(
self,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_is_freeze: MagicMock,
mock_check_unique: MagicMock,
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_update_account: MagicMock,
mock_send_notify: MagicMock,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
@ -480,12 +480,12 @@ class TestChangeEmailReset:
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
self,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_is_freeze: MagicMock,
mock_check_unique: MagicMock,
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_update_account: MagicMock,
mock_send_notify: MagicMock,
app: Flask,
):
"""A verified token for address A must not be replayed to change to address B."""
@ -523,12 +523,12 @@ class TestChangeEmailReset:
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_reject_reset_when_token_account_id_does_not_match_current_user(
self,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_is_freeze: MagicMock,
mock_check_unique: MagicMock,
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_update_account: MagicMock,
mock_send_notify: MagicMock,
app: Flask,
):
from controllers.console.auth.error import InvalidTokenError
@ -575,9 +575,9 @@ class TestAccountServiceSendChangeEmailEmail:
@patch("services.account_service.AccountService.generate_change_email_token")
def test_should_bind_account_id_and_target_email_into_generated_token(
self,
mock_generate_token,
mock_rate_limiter,
mock_mail_task,
mock_generate_token: MagicMock,
mock_rate_limiter: MagicMock,
mock_mail_task: MagicMock,
):
mock_rate_limiter.is_rate_limited.return_value = False
mock_generate_token.return_value = "the-token"
@ -665,7 +665,7 @@ class TestAccountDeletionFeedback:
class TestCheckEmailUnique:
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, app: Flask):
def test_should_normalize_email(self, mock_is_freeze: MagicMock, mock_check_unique: MagicMock, app: Flask):
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
@ -680,7 +680,7 @@ class TestCheckEmailUnique:
assert response == {"result": "success"}
mock_is_freeze.assert_called_once_with("case@test.com")
mock_check_unique.assert_called_once_with("case@test.com")
mock_check_unique.assert_called_once_with("case@test.com", session=ANY)
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():

View File

@ -8,7 +8,7 @@ handler tests use inspect.unwrap() to bypass them and focus on business logic.
import inspect
from datetime import datetime
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -115,7 +115,7 @@ class TestEnterpriseWorkspace:
assert result["message"] == "enterprise workspace created."
assert result["tenant"]["id"] == "tenant-id"
assert result["tenant"]["name"] == "My Workspace"
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY)
mock_tenant_svc.create_tenant_member.assert_called_once_with(
mock_tenant, mock_account, mock_db.session, role="owner"
)
@ -183,5 +183,5 @@ class TestEnterpriseWorkspaceNoOwnerEmail:
assert result["tenant"]["id"] == "tenant-id"
assert result["tenant"]["encrypt_public_key"] == "pub-key"
assert result["tenant"]["custom_config"] == {}
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY)
mock_event.send.assert_called_once_with(mock_tenant)

View File

@ -179,7 +179,7 @@ class TestAccountService:
mock_password_dependencies["compare_password"].return_value = True
# Execute test
result = AccountService.authenticate("test@example.com", "password")
result = AccountService.authenticate("test@example.com", "password", session=mock_db_dependencies["db"].session)
# Verify results
assert result == mock_account
@ -191,7 +191,11 @@ class TestAccountService:
# Execute test and verify exception
self._assert_exception_raised(
AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password"
AccountPasswordError,
AccountService.authenticate,
"notfound@example.com",
"password",
session=mock_db_dependencies["db"].session,
)
def test_authenticate_account_banned(self, mock_db_dependencies):
@ -202,7 +206,13 @@ class TestAccountService:
mock_db_dependencies["db"].session.scalar.return_value = mock_account
# Execute test and verify exception
self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
self._assert_exception_raised(
AccountLoginError,
AccountService.authenticate,
"banned@example.com",
"password",
session=mock_db_dependencies["db"].session,
)
def test_authenticate_password_error(self, mock_db_dependencies, mock_password_dependencies):
"""Test authentication with wrong password."""
@ -215,7 +225,11 @@ class TestAccountService:
# Execute test and verify exception
self._assert_exception_raised(
AccountPasswordError, AccountService.authenticate, "test@example.com", "wrongpassword"
AccountPasswordError,
AccountService.authenticate,
"test@example.com",
"wrongpassword",
session=mock_db_dependencies["db"].session,
)
def test_authenticate_pending_account_activates(self, mock_db_dependencies, mock_password_dependencies):
@ -228,7 +242,9 @@ class TestAccountService:
mock_password_dependencies["compare_password"].return_value = True
# Execute test
result = AccountService.authenticate("pending@example.com", "password")
result = AccountService.authenticate(
"pending@example.com", "password", session=mock_db_dependencies["db"].session
)
# Verify results
assert result == mock_account
@ -253,6 +269,7 @@ class TestAccountService:
interface_language="en-US",
password="password123",
interface_theme="light",
session=mock_db_dependencies["db"].session,
)
# Verify results
@ -290,6 +307,7 @@ class TestAccountService:
interface_language="en-US",
password="password123",
timezone="Asia/Shanghai",
session=mock_db_dependencies["db"].session,
)
assert result.timezone == "Asia/Shanghai"
@ -309,6 +327,7 @@ class TestAccountService:
email="test@example.com",
name="Test User",
interface_language="en-US",
session=MagicMock(),
)
def test_create_account_email_frozen(self, mock_db_dependencies, mock_external_service_dependencies):
@ -325,6 +344,7 @@ class TestAccountService:
email="frozen@example.com",
name="Test User",
interface_language="en-US",
session=mock_db_dependencies["db"].session,
)
dify_config.BILLING_ENABLED = False
@ -341,6 +361,7 @@ class TestAccountService:
interface_language="zh-CN",
password=None,
interface_theme="dark",
session=mock_db_dependencies["db"].session,
)
# Verify results
@ -375,7 +396,9 @@ class TestAccountService:
mock_password_dependencies["hash_password"].return_value = b"new_hashed_password"
# Execute test
result = AccountService.update_account_password(mock_account, "old_password", "new_password123")
result = AccountService.update_account_password(
mock_account, "old_password", "new_password123", session=mock_db_dependencies["db"].session
)
# Verify results
assert result == mock_account
@ -391,7 +414,7 @@ class TestAccountService:
# Verify database operations
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_update_account_password_current_password_incorrect(self, mock_password_dependencies):
def test_update_account_password_current_password_incorrect(self, mock_db_dependencies, mock_password_dependencies):
"""Test password update with incorrect current password."""
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
@ -404,6 +427,7 @@ class TestAccountService:
mock_account,
"wrong_password",
"new_password123",
session=mock_db_dependencies["db"].session,
)
# Verify password comparison was called
@ -411,7 +435,7 @@ class TestAccountService:
"wrong_password", "hashed_password", "salt"
)
def test_update_account_password_invalid_new_password(self, mock_password_dependencies):
def test_update_account_password_invalid_new_password(self, mock_db_dependencies, mock_password_dependencies):
"""Test password update with invalid new password."""
# Setup test data
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
@ -420,7 +444,12 @@ class TestAccountService:
# Execute test and verify exception
self._assert_exception_raised(
ValueError, AccountService.update_account_password, mock_account, "old_password", "short"
ValueError,
AccountService.update_account_password,
mock_account,
"old_password",
"short",
session=mock_db_dependencies["db"].session,
)
# Verify password validation was called
@ -447,19 +476,19 @@ class TestAccountService:
mock_datetime.UTC = "UTC"
# Execute test
result = AccountService.load_user("user-123")
result = AccountService.load_user("user-123", mock_db_dependencies["db"].session)
# Verify results
assert result == mock_account
assert mock_account.set_tenant_id.called
mock_refresh_last_active.assert_called_once_with(mock_account)
mock_refresh_last_active.assert_called_once_with(mock_account, mock_db_dependencies["db"].session)
def test_load_user_not_found(self, mock_db_dependencies):
"""Test user loading when user does not exist."""
mock_db_dependencies["db"].session.get.return_value = None
# Execute test
result = AccountService.load_user("non-existent-user")
result = AccountService.load_user("non-existent-user", mock_db_dependencies["db"].session)
# Verify results
assert result is None
@ -500,14 +529,14 @@ class TestAccountService:
mock_naive_utc_now.return_value = mock_now
# Execute test
result = AccountService.load_user("user-123")
result = AccountService.load_user("user-123", mock_db_dependencies["db"].session)
# Verify results
assert result == mock_account
assert mock_available_tenant.current is True
assert mock_available_tenant.last_opened_at == mock_now
self._assert_database_operations_called(mock_db_dependencies["db"])
mock_refresh_last_active.assert_called_once_with(mock_account)
mock_refresh_last_active.assert_called_once_with(mock_account, mock_db_dependencies["db"].session)
def test_load_user_no_tenants(self, mock_db_dependencies):
"""Test user loading when user has no tenants at all."""
@ -525,7 +554,7 @@ class TestAccountService:
mock_datetime.UTC = "UTC"
# Execute test
result = AccountService.load_user("user-123")
result = AccountService.load_user("user-123", mock_db_dependencies["db"].session)
# Verify results
assert result is None
@ -542,7 +571,7 @@ class TestAccountService:
):
mock_redis_client.set.return_value = True
AccountService._refresh_account_last_active(mock_account)
AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session)
mock_redis_client.set.assert_called_once_with(
"account_last_active_refresh:user-123",
@ -565,7 +594,7 @@ class TestAccountService:
):
mock_redis_client.set.return_value = None
AccountService._refresh_account_last_active(mock_account)
AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session)
mock_redis_client.set.assert_called_once_with(
"account_last_active_refresh:user-123",
@ -586,7 +615,7 @@ class TestAccountService:
patch("services.account_service.naive_utc_now", return_value=now),
patch("services.account_service.redis_client") as mock_redis_client,
):
AccountService._refresh_account_last_active(mock_account)
AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session)
mock_redis_client.set.assert_not_called()
mock_db_dependencies["db"].session.execute.assert_not_called()
@ -736,7 +765,9 @@ class TestTenantService:
mock_credit_pool_db.session.commit = MagicMock()
# Execute test
TenantService.create_owner_tenant_if_not_exist(mock_account)
TenantService.create_owner_tenant_if_not_exist(
mock_account, session=mock_db_dependencies["db"].session
)
# Verify tenant was created with correct parameters
mock_db_dependencies["db"].session.add.assert_called()
@ -838,7 +869,9 @@ class TestTenantService:
mock_sync.return_value = True
# Act
TenantService.remove_member_from_tenant(mock_tenant, mock_pending_member, mock_operator)
TenantService.remove_member_from_tenant(
mock_tenant, mock_pending_member, mock_operator, session=mock_db.session
)
# Assert: enterprise sync still receives the correct member ID
mock_sync.assert_called_once_with(
@ -878,7 +911,9 @@ class TestTenantService:
mock_sync.return_value = True
# Act
TenantService.remove_member_from_tenant(mock_tenant, mock_pending_member, mock_operator)
TenantService.remove_member_from_tenant(
mock_tenant, mock_pending_member, mock_operator, session=mock_db.session
)
# Assert: only the join record should be deleted, not the account
mock_db.session.delete.assert_called_once_with(mock_ta)
@ -909,7 +944,9 @@ class TestTenantService:
mock_sync.return_value = True
# Act
TenantService.remove_member_from_tenant(mock_tenant, mock_active_member, mock_operator)
TenantService.remove_member_from_tenant(
mock_tenant, mock_active_member, mock_operator, session=mock_db.session
)
# Assert: only the join record should be deleted
mock_db.session.delete.assert_called_once_with(mock_ta)
@ -934,7 +971,7 @@ class TestTenantService:
mock_naive_utc_now.return_value = mock_now
# Execute test
TenantService.switch_tenant(mock_account, "tenant-456")
TenantService.switch_tenant(mock_account, "tenant-456", session=mock_db.session)
# Verify tenant was switched
assert mock_tenant_join.current is True
@ -947,7 +984,7 @@ class TestTenantService:
mock_account = TestAccountAssociatedDataFactory.create_account_mock()
# Execute test and verify exception
self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None)
self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None, session=MagicMock())
# ==================== Role Management Tests ====================
@ -971,7 +1008,7 @@ class TestTenantService:
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join]
# Execute test
TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator)
TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator, session=mock_db.session)
# Verify role was updated
assert mock_target_join.role == "admin"
@ -1005,7 +1042,9 @@ class TestTenantService:
):
mock_db_dependencies["db"].session.scalar.return_value = None
TenantService.create_owner_tenant_if_not_exist(mock_account, is_setup=True)
TenantService.create_owner_tenant_if_not_exist(
mock_account, is_setup=True, session=mock_db_dependencies["db"].session
)
mock_rbac_service.MemberRoles.replace.assert_called_once_with(
tenant_id="tenant-rbac",
@ -1030,7 +1069,7 @@ class TestTenantService:
with patch("services.account_service.db") as mock_db:
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join]
TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator)
TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator, session=mock_db.session)
assert mock_target_join.role == "editor"
self._assert_database_operations_called(mock_db)
@ -1052,7 +1091,9 @@ class TestTenantService:
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join]
with pytest.raises(NoPermissionError):
TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator)
TenantService.update_member_role(
mock_tenant, mock_member, "editor", mock_operator, session=mock_db.session
)
def test_admin_cannot_promote_member_to_owner(self):
"""Test admin cannot promote a non-owner member to owner."""
@ -1071,7 +1112,9 @@ class TestTenantService:
mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join]
with pytest.raises(NoPermissionError):
TenantService.update_member_role(mock_tenant, mock_member, "owner", mock_operator)
TenantService.update_member_role(
mock_tenant, mock_member, "owner", mock_operator, session=mock_db.session
)
# ==================== Permission Check Tests ====================
@ -1089,7 +1132,9 @@ class TestTenantService:
mock_db_dependencies["db"].session.scalar.return_value = mock_operator_join
# Execute test - should not raise exception
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add")
TenantService.check_member_permission(
mock_tenant, mock_operator, mock_member, "add", session=mock_db_dependencies["db"].session
)
def test_check_member_permission_operate_self(self):
"""Test member permission check when operator tries to operate self."""
@ -1108,6 +1153,7 @@ class TestTenantService:
mock_operator,
mock_operator, # Same as operator
"add",
session=MagicMock(),
)
def test_admin_can_remove_non_owner_member(self, mock_db_dependencies):
@ -1124,7 +1170,9 @@ class TestTenantService:
)
mock_db_dependencies["db"].session.scalar.side_effect = [mock_operator_join, mock_member_join]
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
TenantService.check_member_permission(
mock_tenant, mock_operator, mock_member, "remove", session=mock_db_dependencies["db"].session
)
def test_admin_cannot_remove_owner_member(self, mock_db_dependencies):
"""Test admin cannot remove an owner member."""
@ -1141,7 +1189,9 @@ class TestTenantService:
mock_db_dependencies["db"].session.scalar.side_effect = [mock_operator_join, mock_member_join]
with pytest.raises(NoPermissionError):
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
TenantService.check_member_permission(
mock_tenant, mock_operator, mock_member, "remove", session=MagicMock()
)
def test_rbac_member_can_remove_non_owner_member(self):
"""Test RBAC workspace.member.manage allows removing a non-owner member."""
@ -1158,7 +1208,9 @@ class TestTenantService:
patch("services.account_service.RBACService.MyPermissions.get", return_value=mock_permissions),
patch("services.account_service.AccountService.is_rbac_workspace_owner", return_value=False),
):
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
TenantService.check_member_permission(
mock_tenant, mock_operator, mock_member, "remove", session=MagicMock()
)
def test_rbac_member_cannot_remove_without_permission(self):
"""Test RBAC permission check rejects removal without workspace.member.manage."""
@ -1175,7 +1227,9 @@ class TestTenantService:
patch("services.account_service.RBACService.MyPermissions.get", return_value=mock_permissions),
):
with pytest.raises(NoPermissionError):
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
TenantService.check_member_permission(
mock_tenant, mock_operator, mock_member, "remove", session=MagicMock()
)
def test_rbac_member_cannot_remove_owner_member(self):
"""Test RBAC permission check rejects removing an owner member."""
@ -1193,7 +1247,9 @@ class TestTenantService:
patch("services.account_service.AccountService.is_rbac_workspace_owner", return_value=True),
):
with pytest.raises(NoPermissionError):
TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove")
TenantService.check_member_permission(
mock_tenant, mock_operator, mock_member, "remove", session=MagicMock()
)
def test_get_rbac_workspace_owner_account_id(self):
mock_roles = MagicMock()
@ -1304,7 +1360,14 @@ class TestRegisterService:
mock_dify_setup.return_value = mock_dify_setup_instance
# Execute test
RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1", "en-US")
RegisterService.setup(
"admin@example.com",
"Admin User",
"password123",
"192.168.1.1",
"en-US",
session=mock_db_dependencies["db"].session,
)
# Verify results
mock_create_account.assert_called_once_with(
@ -1313,8 +1376,11 @@ class TestRegisterService:
interface_language="en-US",
password="password123",
is_setup=True,
session=mock_db_dependencies["db"].session,
)
mock_create_tenant.assert_called_once_with(
account=mock_account, is_setup=True, session=mock_db_dependencies["db"].session
)
mock_create_tenant.assert_called_once_with(account=mock_account, is_setup=True)
mock_dify_setup.assert_called_once()
self._assert_database_operations_called(mock_db_dependencies["db"])
@ -1337,6 +1403,7 @@ class TestRegisterService:
"password123",
"192.168.1.1",
"en-US",
session=mock_db_dependencies["db"].session,
)
# Verify rollback operations were called
@ -1369,10 +1436,13 @@ class TestRegisterService:
name="Test User",
interface_language="en-US",
password=None,
session=mock_db_dependencies["db"].session,
)
assert result == mock_account
mock_create_workspace.assert_called_once_with(account=mock_account)
mock_create_workspace.assert_called_once_with(
account=mock_account, session=mock_db_dependencies["db"].session
)
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled(
@ -1400,9 +1470,12 @@ class TestRegisterService:
name="Test User",
interface_language="en-US",
password=None,
session=mock_db_dependencies["db"].session,
)
mock_create_workspace.assert_called_once_with(account=mock_account)
mock_create_workspace.assert_called_once_with(
account=mock_account, session=mock_db_dependencies["db"].session
)
mock_join_default_workspace.assert_not_called()
def test_create_account_and_tenant_still_calls_default_workspace_join_when_workspace_creation_fails(
@ -1433,6 +1506,7 @@ class TestRegisterService:
name="Test User",
interface_language="en-US",
password=None,
session=mock_db_dependencies["db"].session,
)
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
@ -1470,6 +1544,7 @@ class TestRegisterService:
name="Test User",
password="password123",
language="en-US",
session=mock_db_dependencies["db"].session,
)
# Verify results
@ -1483,8 +1558,11 @@ class TestRegisterService:
password="password123",
is_setup=False,
timezone=None,
session=mock_db_dependencies["db"].session,
)
mock_create_tenant.assert_called_once_with(
"Test User's Workspace", session=mock_db_dependencies["db"].session
)
mock_create_tenant.assert_called_once_with("Test User's Workspace")
mock_create_member.assert_called_once_with(
mock_tenant, mock_account, mock_db_dependencies["db"].session, role="owner"
)
@ -1516,6 +1594,7 @@ class TestRegisterService:
password="password123",
language="en-US",
create_workspace_required=False,
session=mock_db_dependencies["db"].session,
)
assert result == mock_account
@ -1546,6 +1625,7 @@ class TestRegisterService:
password="password123",
language="en-US",
create_workspace_required=False,
session=mock_db_dependencies["db"].session,
)
mock_join_default_workspace.assert_not_called()
@ -1584,6 +1664,7 @@ class TestRegisterService:
name="Test User",
password="password123",
language="en-US",
session=mock_db_dependencies["db"].session,
)
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
@ -1623,6 +1704,7 @@ class TestRegisterService:
name="Test User",
password="password123",
language="en-US",
session=mock_db_dependencies["db"].session,
)
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
@ -1665,11 +1747,14 @@ class TestRegisterService:
open_id="oauth123",
provider="google",
language="en-US",
session=mock_db_dependencies["db"].session,
)
# Verify results
assert result == mock_account
mock_link_account.assert_called_once_with("google", "oauth123", mock_account)
mock_link_account.assert_called_once_with(
"google", "oauth123", mock_account, session=mock_db_dependencies["db"].session
)
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_register_with_pending_status(self, mock_db_dependencies, mock_external_service_dependencies):
@ -1707,6 +1792,7 @@ class TestRegisterService:
password="password123",
language="en-US",
status=AccountStatus.PENDING,
session=mock_db_dependencies["db"].session,
)
# Verify results
@ -1744,6 +1830,7 @@ class TestRegisterService:
name="Test User",
password="password123",
language="en-US",
session=mock_db_dependencies["db"].session,
)
# Verify rollback was called
@ -1767,6 +1854,7 @@ class TestRegisterService:
name="Test User",
password="password123",
language="en-US",
session=mock_db_dependencies["db"].session,
)
# Verify rollback was called
@ -1810,6 +1898,7 @@ class TestRegisterService:
language="en-US",
role="normal",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
# Verify results
@ -1820,6 +1909,7 @@ class TestRegisterService:
language="en-US",
status=AccountStatus.PENDING,
is_setup=True,
session=mock_db_dependencies["db"].session,
)
mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "newuser@example.com")
@ -1856,6 +1946,7 @@ class TestRegisterService:
language="en-US",
role="normal",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
mock_register.assert_called_once_with(
@ -1864,13 +1955,22 @@ class TestRegisterService:
language="en-US",
status=AccountStatus.PENDING,
is_setup=True,
session=mock_db_dependencies["db"].session,
)
mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, mixed_email)
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
mock_check_permission.assert_called_once_with(
mock_tenant,
mock_inviter,
None,
"add",
session=mock_db_dependencies["db"].session,
)
mock_create_member.assert_called_once_with(
mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "normal"
)
mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
mock_switch_tenant.assert_called_once_with(
mock_new_account, mock_tenant.id, session=mock_db_dependencies["db"].session
)
mock_generate_token.assert_called_once_with(
mock_tenant, mock_new_account, "normal", requires_setup=True
)
@ -1912,6 +2012,7 @@ class TestRegisterService:
language="en-US",
role="normal",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
# Verify results
@ -1954,10 +2055,17 @@ class TestRegisterService:
language="en-US",
role="admin",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
assert result == "invite-token-123"
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, mock_existing_account, "add")
mock_check_permission.assert_called_once_with(
mock_tenant,
mock_inviter,
mock_existing_account,
"add",
session=mock_db_dependencies["db"].session,
)
mock_create_member.assert_not_called()
mock_generate_token.assert_called_once_with(
mock_tenant, mock_existing_account, "admin", requires_setup=False
@ -1993,6 +2101,7 @@ class TestRegisterService:
language="en-US",
role="normal",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
mock_lookup.assert_called_once()
@ -2010,6 +2119,7 @@ class TestRegisterService:
language="en-US",
role="normal",
inviter=None,
session=MagicMock(),
)
# ==================== RBAC Member Invitation Tests ====================
@ -2048,6 +2158,7 @@ class TestRegisterService:
language="en-US",
role="rbac-role-id-123",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
assert result == "rbac-token"
@ -2093,6 +2204,7 @@ class TestRegisterService:
language="en-US",
role="rbac-role-id-456",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
assert result == "rbac-token"
@ -2141,6 +2253,7 @@ class TestRegisterService:
language="en-US",
role="rbac-role-id-456",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
mock_create_member.assert_called_once_with(
@ -2191,6 +2304,7 @@ class TestRegisterService:
language="en-US",
role="editor",
inviter=mock_inviter,
session=mock_db_dependencies["db"].session,
)
assert result == "legacy-token"
@ -2300,7 +2414,9 @@ class TestRegisterService:
mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, mock_account]
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
result = RegisterService.get_invitation_if_token_valid(
"tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session
)
# Verify results
assert result is not None
@ -2314,7 +2430,9 @@ class TestRegisterService:
mock_redis_dependencies.get.return_value = None
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
result = RegisterService.get_invitation_if_token_valid(
"tenant-456", "test@example.com", "token-123", session=MagicMock()
)
# Verify results
assert result is None
@ -2333,7 +2451,9 @@ class TestRegisterService:
mock_db_dependencies["db"].session.scalar.return_value = None
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
result = RegisterService.get_invitation_if_token_valid(
"tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session
)
# Verify results
assert result is None
@ -2357,7 +2477,9 @@ class TestRegisterService:
mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, None]
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
result = RegisterService.get_invitation_if_token_valid(
"tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session
)
# Verify results
assert result is None
@ -2384,7 +2506,9 @@ class TestRegisterService:
mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, mock_account]
# Execute test
result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
result = RegisterService.get_invitation_if_token_valid(
"tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session
)
# Verify results
assert result is None
@ -2395,22 +2519,28 @@ class TestRegisterService:
with patch(
"services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation
) as mock_get:
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
result = RegisterService.get_invitation_with_case_fallback(
"tenant-456", "User@Test.com", "token-123", session=MagicMock()
)
assert result == invitation
mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123")
mock_get.assert_called_once_with(
"tenant-456", "User@Test.com", "token-123", session=mock_get.call_args.kwargs["session"]
)
def test_get_invitation_with_case_fallback_retries_with_lowercase(self):
"""Fallback helper should retry with lowercase email when needed."""
invitation = {"workspace_id": "tenant-456"}
with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get:
mock_get.side_effect = [None, invitation]
result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123")
result = RegisterService.get_invitation_with_case_fallback(
"tenant-456", "User@Test.com", "token-123", session=MagicMock()
)
assert result == invitation
assert mock_get.call_args_list == [
(("tenant-456", "User@Test.com", "token-123"),),
(("tenant-456", "user@test.com", "token-123"),),
(("tenant-456", "User@Test.com", "token-123"), {"session": mock_get.call_args_list[0].kwargs["session"]}),
(("tenant-456", "user@test.com", "token-123"), {"session": mock_get.call_args_list[1].kwargs["session"]}),
]
# ==================== Helper Method Tests ====================

View File

@ -546,7 +546,7 @@ class TestAppAnnotationServiceDirectManipulation:
# Act & Assert
with pytest.raises(NotFound):
AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1")
AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1", mock_db.session)
def test_update_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None:
"""Test missing app raises NotFound in update path."""
@ -562,7 +562,7 @@ class TestAppAnnotationServiceDirectManipulation:
# Act & Assert
with pytest.raises(NotFound):
AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1")
AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1", mock_db.session)
def test_update_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None:
"""Test missing question raises ValueError."""
@ -581,7 +581,7 @@ class TestAppAnnotationServiceDirectManipulation:
# Act & Assert
with pytest.raises(ValueError):
AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id)
AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session)
def test_update_app_annotation_directly_should_update_annotation_and_index(self) -> None:
"""Test update changes fields and triggers index update."""
@ -602,7 +602,7 @@ class TestAppAnnotationServiceDirectManipulation:
mock_db.session.get.return_value = annotation
# Act
result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id)
result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session)
# Assert
assert result == annotation
@ -640,7 +640,7 @@ class TestAppAnnotationServiceDirectManipulation:
mock_db.session.scalars.return_value = scalars_result
# Act
AppAnnotationService.delete_app_annotation(app.id, annotation.id)
AppAnnotationService.delete_app_annotation(app.id, annotation.id, mock_db.session)
# Assert
mock_db.session.delete.assert_any_call(annotation)
@ -667,7 +667,7 @@ class TestAppAnnotationServiceDirectManipulation:
# Act & Assert
with pytest.raises(NotFound):
AppAnnotationService.delete_app_annotation("app-1", "ann-1")
AppAnnotationService.delete_app_annotation("app-1", "ann-1", mock_db.session)
def test_delete_app_annotation_should_raise_not_found_when_annotation_missing(self) -> None:
"""Test delete raises NotFound when annotation is missing."""
@ -684,7 +684,7 @@ class TestAppAnnotationServiceDirectManipulation:
# Act & Assert
with pytest.raises(NotFound):
AppAnnotationService.delete_app_annotation(app.id, "ann-1")
AppAnnotationService.delete_app_annotation(app.id, "ann-1", mock_db.session)
def test_delete_app_annotations_in_batch_should_return_zero_when_none_found(self) -> None:
"""Test batch delete returns zero when no annotations found."""