From 21121159624b307aa705c1e6b6c95d288bf75a1f Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 24 Jun 2026 16:29:12 +0900 Subject: [PATCH] chore: make AccountService.load_user use passed session (#37764) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/commands/account.py | 3 +- api/controllers/console/app/annotation.py | 7 +- .../console/app/workflow_comment.py | 3 +- api/controllers/console/auth/activate.py | 10 +- .../console/auth/email_register.py | 3 +- .../console/auth/forgot_password.py | 4 +- api/controllers/console/auth/login.py | 25 +- api/controllers/console/auth/oauth.py | 10 +- .../console/explore/installed_app.py | 2 +- api/controllers/console/init_validate.py | 2 +- api/controllers/console/setup.py | 3 +- api/controllers/console/socketio/workflow.py | 3 +- api/controllers/console/workspace/account.py | 25 +- api/controllers/console/workspace/members.py | 25 +- .../console/workspace/workspace.py | 6 +- .../inner_api/workspace/workspace.py | 4 +- api/controllers/openapi/workspaces.py | 9 +- api/controllers/service_api/app/annotation.py | 7 +- .../processor/paragraph_index_processor.py | 2 +- .../processor/parent_child_index_processor.py | 2 +- api/extensions/ext_login.py | 4 +- api/libs/password.py | 2 +- api/services/account_service.py | 331 +++++++++------- api/services/annotation_service.py | 29 +- api/services/oauth_server.py | 2 +- api/services/workspace_service.py | 4 +- api/tests/integration_tests/conftest.py | 1 + .../controllers/console/auth/test_oauth.py | 6 +- .../controllers/openapi/conftest.py | 18 +- .../controllers/openapi/test_account.py | 7 +- .../controllers/openapi/test_app_dsl.py | 3 +- .../controllers/openapi/test_workspaces.py | 6 +- .../test_dataset_retrieval_integration.py | 40 +- .../services/test_account_service.py | 363 +++++++++++++----- .../services/test_agent_service.py | 3 +- .../services/test_annotation_service.py | 19 +- .../test_api_based_extension_service.py | 3 +- .../services/test_app_dsl_service.py | 5 +- .../services/test_app_generate_service.py | 3 +- .../services/test_app_service.py | 78 ++-- .../services/test_message_service.py | 8 +- .../services/test_oauth_server_service.py | 4 +- .../services/test_ops_service.py | 3 +- .../services/test_saved_message_service.py | 3 +- .../services/test_trigger_provider_service.py | 35 +- .../services/test_web_conversation_service.py | 3 +- .../services/test_webhook_service.py | 3 +- .../services/test_workflow_app_service.py | 6 +- .../services/test_workflow_run_service.py | 12 +- .../test_workflow_tools_manage_service.py | 3 +- .../tasks/test_clean_notion_document_task.py | 35 +- .../test_deal_dataset_vector_index_task.py | 3 +- .../console/auth/test_account_activation.py | 30 +- .../auth/test_email_register_language.py | 3 +- .../console/auth/test_email_verification.py | 3 +- .../console/auth/test_login_logout.py | 36 +- .../console/auth/test_oauth_timezone.py | 8 +- .../console/auth/test_token_refresh.py | 6 +- .../controllers/console/test_init_validate.py | 6 +- .../console/test_workspace_account.py | 64 +-- .../inner_api/workspace/test_workspace.py | 6 +- .../services/test_account_service.py | 238 +++++++++--- .../services/test_annotation_service.py | 14 +- 63 files changed, 1060 insertions(+), 554 deletions(-) diff --git a/api/commands/account.py b/api/commands/account.py index 7f4f0a744f3..dfd57d43142 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -133,8 +133,9 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No password=new_password, language=language, create_workspace_required=False, + session=db.session, ) - TenantService.create_owner_tenant_if_not_exist(account, name) + TenantService.create_owner_tenant_if_not_exist(account, name, session=db.session) click.echo( click.style( diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index edf3a98af8c..48fb4aedc63 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -19,6 +19,7 @@ from controllers.console.wraps import ( rbac_permission_required, setup_required, ) +from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.annotation_fields import ( Annotation, @@ -388,7 +389,9 @@ class AnnotationUpdateDeleteApi(Resource): update_args["answer"] = args.answer if args.question is not None: update_args["question"] = args.question - annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id)) + annotation = AppAnnotationService.update_app_annotation_directly( + update_args, str(app_id), str(annotation_id), db.session + ) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @@ -398,7 +401,7 @@ class AnnotationUpdateDeleteApi(Resource): @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT) @console_ns.response(204, "Annotation deleted successfully") def delete(self, app_id: UUID, annotation_id: UUID): - AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id)) + AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id), db.session) return "", 204 diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py index a9bf85ed36c..c70f00dcfa1 100644 --- a/api/controllers/console/app/workflow_comment.py +++ b/api/controllers/console/app/workflow_comment.py @@ -17,6 +17,7 @@ from controllers.console.wraps import ( with_current_tenant_id, with_current_user, ) +from extensions.ext_database import db from fields.base import ResponseModel from fields.member_fields import AccountWithRole from libs.helper import build_avatar_url, dump_response, to_timestamp @@ -489,7 +490,7 @@ class WorkflowCommentMentionUsersApi(Resource): current_tenant = current_user.current_tenant # need the tenant object here if current_tenant is None: raise ValueError("current tenant is required") - members = TenantService.get_tenant_members(current_tenant) + members = TenantService.get_tenant_members(current_tenant, session=db.session) users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) response = WorkflowCommentMentionUsersPayload(users=users) return response.model_dump(mode="json"), 200 diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index c9142d85ede..b6045685b55 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -89,7 +89,9 @@ class ActivateCheckApi(Resource): workspaceId = args.workspace_id token = args.token - invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token) + invitation = RegisterService.get_invitation_with_case_fallback( + workspaceId, args.email, token, session=db.session + ) if invitation: data = invitation.get("data", {}) tenant = invitation.get("tenant", None) @@ -137,7 +139,9 @@ class ActivateApi(Resource): args = ActivatePayload.model_validate(console_ns.payload) normalized_request_email = args.email.lower() if args.email else None - invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token) + invitation = RegisterService.get_invitation_with_case_fallback( + args.workspace_id, args.email, args.token, session=db.session + ) if invitation is None: raise AlreadyActivateError() @@ -184,6 +188,6 @@ class ActivateApi(Resource): account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() - TenantService.switch_tenant(account, tenant.id) + TenantService.switch_tenant(account, tenant.id, session=db.session) return {"result": "success"} diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index ccbe9405fe5..ba4fc1275d9 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -187,7 +187,7 @@ class EmailRegisterResetApi(Resource): timezone=args.timezone, language=args.language, ) - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} @@ -206,6 +206,7 @@ class EmailRegisterResetApi(Resource): password=password, interface_language=get_valid_language(language), timezone=timezone, + session=db.session, ) except AccountRegisterError: raise AccountInFreezeError() diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 061c29a13a2..8df9600070c 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -198,10 +198,10 @@ class ForgotPasswordResetApi(Resource): # Create workspace if needed if ( - not TenantService.get_join_tenants(account) + not TenantService.get_join_tenants(account, session=db.session) and FeatureService.get_system_features().is_allow_create_workspace ): - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session) TenantService.create_tenant_member(tenant, account, db.session, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 053f313ba53..81f9ee4bae4 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -119,7 +119,9 @@ class LoginApi(Resource): invite_token = args.invite_token invitation_data: InvitationDetailDict | None = None if invite_token: - invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token) + invitation_data = RegisterService.get_invitation_with_case_fallback( + None, request_email, invite_token, session=db.session + ) if invitation_data is None: invite_token = None @@ -145,7 +147,7 @@ class LoginApi(Resource): _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS) raise AuthenticationFailedError() from exc # SELF_HOSTED only have one workspace - tenants = TenantService.get_join_tenants(account) + tenants = TenantService.get_join_tenants(account, session=db.session) if len(tenants) == 0: system_features = FeatureService.get_system_features() @@ -157,7 +159,7 @@ class LoginApi(Resource): "data": "workspace not found, please contact system admin to invite you to join in a workspace", } - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(normalized_email) # Create response with cookies instead of returning tokens in body @@ -291,7 +293,7 @@ class EmailCodeLoginApi(Resource): _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE) raise AccountInFreezeError() if account: - tenants = TenantService.get_join_tenants(account) + tenants = TenantService.get_join_tenants(account, session=db.session) if not tenants: workspaces = FeatureService.get_system_features().license.workspaces if not workspaces.is_available(): @@ -299,7 +301,7 @@ class EmailCodeLoginApi(Resource): if not FeatureService.get_system_features().is_allow_create_workspace: raise NotAllowedCreateWorkspace() else: - new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session) TenantService.create_tenant_member(new_tenant, account, db.session, role="owner") account.current_tenant = new_tenant tenant_was_created.send(new_tenant) @@ -311,6 +313,7 @@ class EmailCodeLoginApi(Resource): name=user_email, interface_language=get_valid_language(language), timezone=args.timezone, + session=db.session, ) except WorkSpaceNotAllowedCreateError: raise NotAllowedCreateWorkspace() @@ -319,7 +322,7 @@ class EmailCodeLoginApi(Resource): raise AccountInFreezeError() except WorkspacesLimitExceededError: raise WorkspacesLimitExceeded() - token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) + token_pair = AccountService.login(account, session=db.session, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(user_email) # Create response with cookies instead of returning tokens in body @@ -343,7 +346,7 @@ class RefreshTokenApi(Resource): return {"result": "fail", "message": "No refresh token provided"}, 401 try: - new_token_pair = AccountService.refresh_token(refresh_token) + new_token_pair = AccountService.refresh_token(refresh_token, session=db.session) # Create response with new cookies response = make_response({"result": "success"}) @@ -358,22 +361,22 @@ class RefreshTokenApi(Resource): def _get_account_with_case_fallback(email: str): - account = AccountService.get_user_through_email(email) + account = AccountService.get_user_through_email(email, session=db.session) if account or email == email.lower(): return account - return AccountService.get_user_through_email(email.lower()) + return AccountService.get_user_through_email(email.lower(), session=db.session) def _authenticate_account_with_case_fallback( original_email: str, normalized_email: str, password: str, invite_token: str | None ): try: - return AccountService.authenticate(original_email, password, invite_token) + return AccountService.authenticate(original_email, password, invite_token, session=db.session) except services.errors.account.AccountPasswordError: if original_email == normalized_email: raise - return AccountService.authenticate(normalized_email, password, invite_token) + return AccountService.authenticate(normalized_email, password, invite_token, session=db.session) def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None: diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 670d1c7818d..65f3a5addde 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -195,7 +195,7 @@ class OAuthCallback(Resource): db.session.commit() try: - TenantService.create_owner_tenant_if_not_exist(account) + TenantService.create_owner_tenant_if_not_exist(account, session=db.session) except Unauthorized: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.") except WorkSpaceNotAllowedCreateError: @@ -206,6 +206,7 @@ class OAuthCallback(Resource): token_pair = AccountService.login( account=account, + session=db.session, ip_address=extract_remote_ip(request), ) @@ -240,12 +241,12 @@ def _generate_account( oauth_new_user = False if account: - tenants = TenantService.get_join_tenants(account) + tenants = TenantService.get_join_tenants(account, session=db.session) if not tenants: if not FeatureService.get_system_features().is_allow_create_workspace: raise WorkSpaceNotAllowedCreateError() else: - new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session) TenantService.create_tenant_member(new_tenant, account, db.session, role="owner") account.current_tenant = new_tenant tenant_was_created.send(new_tenant) @@ -272,9 +273,10 @@ def _generate_account( provider=provider, language=interface_language, timezone=timezone, + session=db.session, ) # Link account - AccountService.link_account_integrate(provider, user_info.id, account) + AccountService.link_account_integrate(provider, user_info.id, account, session=db.session) return account, oauth_new_user diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index c1fa1378ffa..fd5b003b523 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -175,7 +175,7 @@ class InstalledAppsListApi(Resource): if current_user.current_tenant is None: raise ValueError("current_user.current_tenant must not be None") - current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) + current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant, session=db.session) installed_app_list: list[dict[str, Any]] = [] for installed_app, app_model in installed_apps: installed_app_list.append( diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index f086bf18622..27f6bcc36dc 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -50,7 +50,7 @@ def get_init_status() -> InitStatusResponse: @only_edition_self_hosted def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse: """Validate initialization password.""" - tenant_count = TenantService.get_tenant_count() + tenant_count = TenantService.get_tenant_count(session=db.session) if tenant_count > 0: raise AlreadySetupError() diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 279e4ec502d..3b5c1bbe18f 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -79,7 +79,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse: if get_setup_status(): raise AlreadySetupError() - tenant_count = TenantService.get_tenant_count() + tenant_count = TenantService.get_tenant_count(session=db.session) if tenant_count > 0: raise AlreadySetupError() @@ -94,6 +94,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse: password=payload.password, ip_address=extract_remote_ip(request), language=payload.language, + session=db.session, ) return SetupResponse(result="success") diff --git a/api/controllers/console/socketio/workflow.py b/api/controllers/console/socketio/workflow.py index b4f03593fd7..99e56df3cb8 100644 --- a/api/controllers/console/socketio/workflow.py +++ b/api/controllers/console/socketio/workflow.py @@ -4,6 +4,7 @@ from typing import cast from flask import Request as FlaskRequest +from extensions.ext_database import db from extensions.ext_socketio import sio from libs.passport import PassportService from libs.token import extract_access_token @@ -43,7 +44,7 @@ def socket_connect(sid, environ, auth): return False with sio.app.app_context(): - user = AccountService.load_logged_in_account(account_id=user_id) + user = AccountService.load_logged_in_account(account_id=user_id, session=db.session) if not user: logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid) return False diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 0ac26168bc5..c13c8aa162f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -328,7 +328,7 @@ class AccountNameApi(Resource): def post(self, current_user: Account): payload = console_ns.payload or {} args = AccountNamePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, name=args.name) + updated_account = AccountService.update_account(current_user, session=db.session, name=args.name) return _serialize_account(updated_account) @@ -374,7 +374,7 @@ class AccountAvatarApi(Resource): payload = console_ns.payload or {} args = AccountAvatarPayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, avatar=args.avatar) + updated_account = AccountService.update_account(current_user, session=db.session, avatar=args.avatar) return _serialize_account(updated_account) @@ -391,7 +391,9 @@ class AccountInterfaceLanguageApi(Resource): payload = console_ns.payload or {} args = AccountInterfaceLanguagePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) + updated_account = AccountService.update_account( + current_user, session=db.session, interface_language=args.interface_language + ) return _serialize_account(updated_account) @@ -408,7 +410,9 @@ class AccountInterfaceThemeApi(Resource): payload = console_ns.payload or {} args = AccountInterfaceThemePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) + updated_account = AccountService.update_account( + current_user, session=db.session, interface_theme=args.interface_theme + ) return _serialize_account(updated_account) @@ -425,7 +429,7 @@ class AccountTimezoneApi(Resource): payload = console_ns.payload or {} args = AccountTimezonePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, timezone=args.timezone) + updated_account = AccountService.update_account(current_user, session=db.session, timezone=args.timezone) return _serialize_account(updated_account) @@ -443,7 +447,8 @@ class AccountPasswordApi(Resource): args = AccountPasswordPayload.model_validate(payload) try: - AccountService.update_account_password(current_user, args.password, args.new_password) + assert args.password is not None + AccountService.update_account_password(current_user, args.password, args.new_password, session=db.session) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() @@ -731,7 +736,7 @@ class ChangeEmailResetApi(Resource): if AccountService.is_account_in_freeze(normalized_new_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(normalized_new_email): + if not AccountService.check_email_unique(normalized_new_email, session=db.session): raise EmailAlreadyInUseError() reset_data = AccountService.get_change_email_data(args.token) @@ -755,7 +760,9 @@ class ChangeEmailResetApi(Resource): # legitimately verified token. AccountService.revoke_change_email_token(args.token) - updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) + updated_account = AccountService.update_account_email( + current_user, email=normalized_new_email, session=db.session + ) AccountService.send_change_email_completed_notify_email( email=normalized_new_email, @@ -775,6 +782,6 @@ class CheckEmailUnique(Resource): normalized_email = args.email.lower() if AccountService.is_account_in_freeze(normalized_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(normalized_email): + if not AccountService.check_email_unique(normalized_email, session=db.session): raise EmailAlreadyInUseError() return {"result": "success"} diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 4ea77e04b96..3a2e3c92359 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -186,7 +186,7 @@ class MemberListApi(Resource): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") - members = TenantService.get_tenant_members(current_user.current_tenant) + members = TenantService.get_tenant_members(current_user.current_tenant, session=db.session) if dify_config.RBAC_ENABLED: member_ids = [member.id for member in members] member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get( @@ -273,6 +273,7 @@ class MemberInviteEmailApi(Resource): language=interface_language, role=invitee_role, inviter=inviter, + session=db.session, ) encoded_invitee_email = parse.quote(invitee_email) invitation_results.append( @@ -317,7 +318,9 @@ class MemberCancelInviteApi(Resource): abort(404) else: try: - TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) + TenantService.remove_member_from_tenant( + current_user.current_tenant, member, current_user, session=db.session + ) except services.errors.account.CannotOperateSelfError as e: return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: @@ -360,7 +363,9 @@ class MemberUpdateRoleApi(Resource): try: assert member is not None, "Member not found" - TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) + TenantService.update_member_role( + current_user.current_tenant, member, new_role, current_user, session=db.session + ) except services.errors.account.CannotOperateSelfError as e: return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: @@ -387,7 +392,7 @@ class DatasetOperatorMemberListApi(Resource): def get(self, current_user: Account): if not current_user.current_tenant: raise ValueError("No current tenant") - members = TenantService.get_dataset_operator_members(current_user.current_tenant) + members = TenantService.get_dataset_operator_members(current_user.current_tenant, session=db.session) member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) response = AccountWithRoleList(accounts=member_models) return response.model_dump(mode="json"), 200 @@ -413,7 +418,7 @@ class SendOwnerTransferEmailApi(Resource): # check if the current user is the owner of the workspace if not current_user.current_tenant: raise ValueError("No current tenant") - if not TenantService.is_owner(current_user, current_user.current_tenant): + if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session): raise NotOwnerError() if args.language is not None and args.language == "zh-Hans": @@ -448,7 +453,7 @@ class OwnerTransferCheckApi(Resource): # check if the current user is the owner of the workspace if not current_user.current_tenant: raise ValueError("No current tenant") - if not TenantService.is_owner(current_user, current_user.current_tenant): + if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session): raise NotOwnerError() user_email = current_user.email @@ -494,7 +499,7 @@ class OwnerTransfer(Resource): # check if the current user is the owner of the workspace if not current_user.current_tenant: raise ValueError("No current tenant") - if not TenantService.is_owner(current_user, current_user.current_tenant): + if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session): raise NotOwnerError() if current_user.id == str(member_id): @@ -516,12 +521,14 @@ class OwnerTransfer(Resource): if not current_user.current_tenant: raise ValueError("No current tenant") - if not TenantService.is_member(member, current_user.current_tenant): + if not TenantService.is_member(member, current_user.current_tenant, session=db.session): raise MemberNotInTenantError() try: assert member is not None, "Member not found" - TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user) + TenantService.update_member_role( + current_user.current_tenant, member, "owner", current_user, session=db.session + ) AccountService.send_new_owner_transfer_notify_email( account=member, diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 59a33fe0385..0afd7e06bf7 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -325,10 +325,10 @@ class TenantApi(Resource): raise ValueError("No current tenant") if tenant.status == TenantStatus.ARCHIVE: - tenants = TenantService.get_join_tenants(current_user) + tenants = TenantService.get_join_tenants(current_user, session=db.session) # if there is any tenant, switch to the first one if len(tenants) > 0: - TenantService.switch_tenant(current_user, tenants[0].id) + TenantService.switch_tenant(current_user, tenants[0].id, session=db.session) tenant = tenants[0] # else, raise Unauthorized else: @@ -351,7 +351,7 @@ class SwitchWorkspaceApi(Resource): # check if tenant_id is valid, 403 if not try: - TenantService.switch_tenant(current_user, args.tenant_id) + TenantService.switch_tenant(current_user, args.tenant_id, session=db.session) except Exception: raise AccountNotLinkTenantError("Account not link tenant") diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index dd93616e6b1..1f25eb576d3 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -47,7 +47,7 @@ class EnterpriseWorkspace(Resource): if account is None: return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) + tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session) TenantService.create_tenant_member(tenant, account, db.session, role="owner") tenant_was_created.send(tenant) @@ -84,7 +84,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): def post(self): args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {}) - tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) + tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session) tenant_was_created.send(tenant) diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index 5653fbae432..49f8fb9656f 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -128,7 +128,7 @@ class WorkspaceSwitchApi(Resource): account = _load_account(auth_data.account_id) try: - TenantService.switch_tenant(account, workspace_id) + TenantService.switch_tenant(account, workspace_id, session=db.session) except AccountNotLinkTenantError: raise NotFound("workspace not found") @@ -152,7 +152,7 @@ class WorkspaceMembersApi(Resource): @accepts(query=MemberListQuery) def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery): tenant = _load_tenant(workspace_id) - members = TenantService.get_tenant_members(tenant) + members = TenantService.get_tenant_members(tenant, session=db.session) total = len(members) start = (query.page - 1) * query.limit page_items = members[start : start + query.limit] @@ -184,6 +184,7 @@ class WorkspaceMembersApi(Resource): language=None, role=body.role, inviter=inviter, + session=db.session, ) except AccountAlreadyInTenantError as exc: raise BadRequest(str(exc)) @@ -232,7 +233,7 @@ class WorkspaceMemberApi(Resource): raise NotFound("member not found") try: - TenantService.remove_member_from_tenant(tenant, member, operator) + TenantService.remove_member_from_tenant(tenant, member, operator, session=db.session) except CannotOperateSelfError as exc: raise BadRequest(str(exc)) except NoPermissionError as exc: @@ -266,7 +267,7 @@ class WorkspaceMemberRoleApi(Resource): raise NotFound("member not found") try: - TenantService.update_member_role(tenant, member, body.role, operator) + TenantService.update_member_role(tenant, member, body.role, operator, session=db.session) except CannotOperateSelfError as exc: raise BadRequest(str(exc)) except NoPermissionError as exc: diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 627545d7168..8a57ec9818a 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -10,6 +10,7 @@ from controllers.common.schema import query_params_from_model, register_response from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token +from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.annotation_fields import Annotation, AnnotationList from fields.base import ResponseModel @@ -281,7 +282,9 @@ class AnnotationUpdateDeleteApi(Resource): """Update an existing annotation.""" payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer} - annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id)) + annotation = AppAnnotationService.update_app_annotation_directly( + update_args, app_model.id, str(annotation_id), db.session + ) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json") @@ -310,5 +313,5 @@ class AnnotationUpdateDeleteApi(Resource): @edit_permission_required def delete(self, app_model: App, annotation_id: UUID): """Delete an annotation.""" - AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id)) + AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id), db.session) return "", 204 diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index f68e5a4e6b3..dd173207b09 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -226,7 +226,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_multimodal_documents.append(file_document) doc.attachments = attachments else: - account = AccountService.load_user(document.created_by) + account = AccountService.load_user(document.created_by, db.session) if not account: raise ValueError("Invalid account") doc.attachments = self._get_content_files(doc, current_user=account) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 9c186a9f046..78d8b7dcd53 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -291,7 +291,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): attachments.append(file_document) doc.attachments = attachments else: - account = AccountService.load_user(document.created_by) + account = AccountService.load_user(document.created_by, db.session) if not account: raise ValueError("Invalid account") doc.attachments = self._get_content_files(doc, current_user=account) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 0ae018f6a1d..f6496c70a78 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -84,7 +84,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non if not user_id: raise Unauthorized("Invalid Authorization token.") - logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + logged_in_account = AccountService.load_logged_in_account(account_id=user_id, session=db.session) return logged_in_account elif request.blueprint == "openapi": # Account-branch device-flow approval routes (approve / deny / @@ -103,7 +103,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non source = decoded.get("token_source") if source or not user_id: return None - return AccountService.load_logged_in_account(account_id=user_id) + return AccountService.load_logged_in_account(account_id=user_id, session=db.session) elif request.blueprint == "web": app_code = request.headers.get(HEADER_NAME_APP_CODE) webapp_token = extract_webapp_passport(app_code, request) if app_code else None diff --git a/api/libs/password.py b/api/libs/password.py index 3313278492a..4170abbf225 100644 --- a/api/libs/password.py +++ b/api/libs/password.py @@ -16,7 +16,7 @@ def valid_password(password): raise ValueError("Password must contain letters and numbers, and the length must be at least 8 characters.") -def hash_password(password_str, salt_byte): +def hash_password(password_str: str, salt_byte: bytes): dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000) return binascii.hexlify(dk) diff --git a/api/services/account_service.py b/api/services/account_service.py index 80411dd288e..7ab757040bb 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1,3 +1,10 @@ +"""Account, workspace, and invitation services. + +Database access in this module is caller-scoped: methods that read or mutate ORM state accept an explicit +``session`` so controllers, tasks, and tests can control transaction lifetime and avoid hidden Flask-scoped session +usage inside service logic. +""" + import base64 import json import logging @@ -235,7 +242,7 @@ class AccountService: ) @staticmethod - def _refresh_account_last_active(account: Account) -> None: + def _refresh_account_last_active(account: Account, session: scoped_session | Session) -> None: now = naive_utc_now() refresh_before = now - ACCOUNT_LAST_ACTIVE_REFRESH_INTERVAL @@ -245,12 +252,12 @@ class AccountService: if not AccountService._should_refresh_account_last_active(account.id): return - db.session.execute( + session.execute( update(Account) .where(Account.id == account.id, Account.last_active_at < refresh_before) .values(last_active_at=now, updated_at=func.current_timestamp()) ) - db.session.commit() + session.commit() @staticmethod def _store_refresh_token(refresh_token: str, account_id: str): @@ -295,20 +302,20 @@ class AccountService: side-effects (current-tenant assignment, commit) are unwanted. ``session`` is injected by the caller so this service stays free - of the Flask-scoped ``db.session`` import. + of a Flask-scoped session import. """ return session.get(Account, account_id) @staticmethod - def load_user(user_id: str) -> None | Account: - account = db.session.get(Account, user_id) + def load_user(user_id: str, session: scoped_session | Session) -> None | Account: + account = session.get(Account, user_id) if not account: return None if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") - current_tenant = db.session.scalar( + current_tenant = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True) .limit(1) @@ -316,7 +323,7 @@ class AccountService: if current_tenant: account.set_tenant_id(current_tenant.tenant_id) else: - available_ta = db.session.scalar( + available_ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.account_id == account.id) .order_by(TenantAccountJoin.id.asc()) @@ -328,13 +335,13 @@ class AccountService: account.set_tenant_id(available_ta.tenant_id) available_ta.current = True available_ta.last_opened_at = naive_utc_now() - db.session.commit() + session.commit() - AccountService._refresh_account_last_active(account) + AccountService._refresh_account_last_active(account, session) # NOTE: make sure account is accessible outside of a db session # This ensures that it will work correctly after upgrading to Flask version 3.1.2 - db.session.refresh(account) - db.session.close() + session.refresh(account) + session.close() return account @staticmethod @@ -352,10 +359,12 @@ class AccountService: return token @staticmethod - def authenticate(email: str, password: str, invite_token: str | None = None) -> Account: + def authenticate( + email: str, password: str, invite_token: str | None = None, *, session: scoped_session | Session + ) -> Account: """authenticate account with email and password""" - account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) + account = session.scalar(select(Account).where(Account.email == email).limit(1)) if not account: raise AccountPasswordError("Invalid email or password.") @@ -378,12 +387,14 @@ class AccountService: account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() - db.session.commit() + session.commit() return account @staticmethod - def update_account_password(account, password, new_password): + def update_account_password( + account: Account, password: str, new_password: str, *, session: scoped_session | Session + ): """update account password""" if account.password and not compare_password(password, account.password, account.password_salt): raise CurrentPasswordIncorrectError("Current password is incorrect.") @@ -400,8 +411,8 @@ class AccountService: base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - db.session.add(account) - db.session.commit() + session.add(account) + session.commit() return account @staticmethod @@ -413,6 +424,8 @@ class AccountService: interface_theme: str = "light", is_setup: bool | None = False, timezone: str | None = None, + *, + session: scoped_session | Session, ) -> Account: """Create an account, preferring explicit user timezone over language-derived defaults.""" if not FeatureService.get_system_features().is_allow_register and not is_setup: @@ -458,13 +471,19 @@ class AccountService: timezone=resolved_timezone, ) - db.session.add(account) - db.session.commit() + session.add(account) + session.commit() return account @staticmethod def create_account_and_tenant( - email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None + email: str, + name: str, + interface_language: str, + password: str | None = None, + timezone: str | None = None, + *, + session: scoped_session | Session, ) -> Account: """Create an account and owner workspace.""" account = AccountService.create_account( @@ -473,10 +492,11 @@ class AccountService: interface_language=interface_language, password=password, timezone=timezone, + session=session, ) try: - TenantService.create_owner_tenant_if_not_exist(account=account) + TenantService.create_owner_tenant_if_not_exist(account=account, session=session) except Exception: # Enterprise-only side-effect should run independently from personal workspace creation. _try_join_enterprise_default_workspace(str(account.id)) @@ -536,11 +556,11 @@ class AccountService: delete_account_task.delay(account.id) @staticmethod - def link_account_integrate(provider: str, open_id: str, account: Account): + def link_account_integrate(provider: str, open_id: str, account: Account, *, session: scoped_session | Session): """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: AccountIntegrate | None = db.session.scalar( + account_integrate: AccountIntegrate | None = session.scalar( select(AccountIntegrate) .where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider) .limit(1) @@ -556,62 +576,62 @@ class AccountService: account_integrate = AccountIntegrate( account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" ) - db.session.add(account_integrate) + session.add(account_integrate) - db.session.commit() + session.commit() logger.info("Account %s linked %s account %s.", account.id, provider, open_id) except Exception as e: logger.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id) raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod - def close_account(account: Account): + def close_account(account: Account, *, session: scoped_session | Session): """Close account""" account.status = AccountStatus.CLOSED - db.session.commit() + session.commit() @staticmethod - def update_account(account, **kwargs): + def update_account(account: Account, *, session: scoped_session | Session, **kwargs): """Update account fields""" - account = db.session.merge(account) + account = session.merge(account) for field, value in kwargs.items(): if hasattr(account, field): setattr(account, field, value) else: raise AttributeError(f"Invalid field: {field}") - db.session.commit() + session.commit() return account @staticmethod - def update_account_email(account: Account, email: str) -> Account: + def update_account_email(account: Account, email: str, session: scoped_session | Session) -> Account: """Update account email""" account.email = email - account_integrate = db.session.scalar( + account_integrate = session.scalar( select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1) ) if account_integrate: - db.session.delete(account_integrate) - db.session.add(account) - db.session.commit() + session.delete(account_integrate) + session.add(account) + session.commit() return account @staticmethod - def update_login_info(account: Account, *, ip_address: str): + def update_login_info(account: Account, session: scoped_session | Session, *, ip_address: str): """Update last login time and ip""" account.last_login_at = naive_utc_now() account.last_login_ip = ip_address - db.session.add(account) - db.session.commit() + session.add(account) + session.commit() @staticmethod - def login(account: Account, *, ip_address: str | None = None) -> TokenPair: + def login(account: Account, *, session: scoped_session | Session, ip_address: str | None = None) -> TokenPair: if ip_address: - AccountService.update_login_info(account=account, ip_address=ip_address) + AccountService.update_login_info(account=account, session=session, ip_address=ip_address) if account.status == AccountStatus.PENDING: account.status = AccountStatus.ACTIVE - db.session.commit() + session.commit() access_token = AccountService.get_account_jwt_token(account=account) refresh_token = _generate_refresh_token() @@ -628,13 +648,13 @@ class AccountService: AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) @staticmethod - def refresh_token(refresh_token: str) -> TokenPair: + def refresh_token(refresh_token: str, *, session: scoped_session | Session) -> TokenPair: # Verify the refresh token account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) if not account_id: raise ValueError("Invalid refresh token") - account = AccountService.load_user(account_id.decode("utf-8")) + account = AccountService.load_user(account_id.decode("utf-8"), session) if not account: raise ValueError("Invalid account") @@ -649,8 +669,8 @@ class AccountService: return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token) @staticmethod - def load_logged_in_account(*, account_id: str): - return AccountService.load_user(account_id) + def load_logged_in_account(*, account_id: str, session: scoped_session | Session): + return AccountService.load_user(account_id, session) @classmethod def send_reset_password_email( @@ -1002,7 +1022,7 @@ class AccountService: TokenManager.revoke_token(token, "email_code_login") @classmethod - def get_user_through_email(cls, email: str): + def get_user_through_email(cls, email: str, *, session: scoped_session | Session): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): raise AccountRegisterError( description=( @@ -1011,7 +1031,7 @@ class AccountService: ) ) - account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) + account = session.scalar(select(Account).where(Account.email == email).limit(1)) if not account: return None @@ -1210,13 +1230,19 @@ class AccountService: return False @staticmethod - def check_email_unique(email: str) -> bool: - return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None + def check_email_unique(email: str, *, session: scoped_session | Session) -> bool: + return session.scalar(select(Account).where(Account.email == email).limit(1)) is None class TenantService: @staticmethod - def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant: + def create_tenant( + name: str, + is_setup: bool | None = False, + is_from_dashboard: bool | None = False, + *, + session: scoped_session | Session, + ) -> Tenant: """Create tenant""" if ( not FeatureService.get_system_features().is_allow_create_workspace @@ -1228,8 +1254,8 @@ class TenantService: raise NotAllowedCreateWorkspace() tenant = Tenant(name=name) - db.session.add(tenant) - db.session.commit() + session.add(tenant) + session.commit() for category in TenantPluginAutoUpgradeStrategy.PluginCategory: plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy( @@ -1241,11 +1267,11 @@ class TenantService: exclude_plugins=[], include_plugins=[], ) - db.session.add(plugin_upgrade_strategy) - db.session.commit() + session.add(plugin_upgrade_strategy) + session.commit() tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.commit() + session.commit() from services.credit_pool_service import CreditPoolService @@ -1254,9 +1280,11 @@ class TenantService: return tenant @staticmethod - def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False): + def create_owner_tenant_if_not_exist( + account: Account, name: str | None = None, is_setup: bool | None = False, *, session: scoped_session | Session + ): """Check if user have a workspace or not""" - available_ta = db.session.scalar( + available_ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.account_id == account.id) .order_by(TenantAccountJoin.id.asc()) @@ -1275,10 +1303,10 @@ class TenantService: raise WorkspacesLimitExceededError() if name: - tenant = TenantService.create_tenant(name=name, is_setup=is_setup) + tenant = TenantService.create_tenant(name=name, is_setup=is_setup, session=session) else: - tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) - TenantService.create_tenant_member(tenant, account, db.session, role="owner") + tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup, session=session) + TenantService.create_tenant_member(tenant, account, session, role="owner") if dify_config.RBAC_ENABLED: owner_role_id = AccountService._resolve_legacy_role_id(str(tenant.id), account.id, TenantAccountRole.OWNER) RBACService.MemberRoles.replace( @@ -1288,16 +1316,16 @@ class TenantService: role_ids=[owner_role_id], ) account.current_tenant = tenant - db.session.commit() + session.commit() tenant_was_created.send(tenant) @staticmethod def create_tenant_member( - tenant: Tenant, account: Account, session: scoped_session, role: str = "normal" + tenant: Tenant, account: Account, session: scoped_session | Session, role: str = "normal" ) -> TenantAccountJoin: """Create tenant member""" if role == TenantAccountRole.OWNER: - if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]): + if TenantService.has_roles(tenant, [TenantAccountRole.OWNER], session=session): logger.error("Tenant %s has already an owner.", tenant.id) raise Exception("Tenant already has an owner.") @@ -1318,10 +1346,10 @@ class TenantService: return ta @staticmethod - def get_join_tenants(account: Account) -> list[Tenant]: + def get_join_tenants(account: Account, *, session: scoped_session | Session) -> list[Tenant]: """Get account join tenants""" return list( - db.session.scalars( + session.scalars( select(Tenant) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) @@ -1340,7 +1368,7 @@ class TenantService: membership + pick the default workspace. ``session`` is injected by the caller so this service stays free - of the Flask-scoped ``db.session`` import. + of a Flask-scoped session import. No tenant-status filter: parity with the legacy controller query (the openapi identity endpoint listed all joined tenants). @@ -1413,7 +1441,7 @@ class TenantService: bearers (no account) collapse to the non-member path. Mirrors the session-injection style of :meth:`account_belongs_to_tenant` rather than :meth:`get_user_role`, which loads full ``Account``/``Tenant`` - objects against the Flask-scoped ``db.session``. + objects against the Flask-scoped session. """ if not account_id: return None @@ -1479,13 +1507,13 @@ class TenantService: ).first() @staticmethod - def get_current_tenant_by_account(account: Account): + def get_current_tenant_by_account(account: Account, *, session: scoped_session | Session): """Get tenant by account and add the role""" tenant = account.current_tenant if not tenant: raise TenantNotFoundError("Tenant not found.") - ta = db.session.scalar( + ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1497,14 +1525,14 @@ class TenantService: return tenant @staticmethod - def switch_tenant(account: Account, tenant_id: str | None = None): + def switch_tenant(account: Account, tenant_id: str | None = None, *, session: scoped_session | Session): """Switch the current workspace for the account""" # Ensure tenant_id is provided if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = db.session.scalar( + tenant_account_join = session.scalar( select(TenantAccountJoin) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) .where( @@ -1518,7 +1546,7 @@ class TenantService: if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - db.session.execute( + session.execute( update(TenantAccountJoin) .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id) .values(current=False) @@ -1527,10 +1555,10 @@ class TenantService: tenant_account_join.last_opened_at = naive_utc_now() # Set the current tenant for the account account.set_tenant_id(tenant_account_join.tenant_id) - db.session.commit() + session.commit() @staticmethod - def get_tenant_members(tenant: Tenant) -> list[Account]: + def get_tenant_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]: """Get tenant members""" stmt = ( select(Account, TenantAccountJoin.role) @@ -1542,14 +1570,14 @@ class TenantService: # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in db.session.execute(stmt): + for account, role in session.execute(stmt): account.role = role updated_accounts.append(account) return updated_accounts @staticmethod - def get_dataset_operator_members(tenant: Tenant) -> list[Account]: + def get_dataset_operator_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]: """Get dataset admin members""" stmt = ( select(Account, TenantAccountJoin.role) @@ -1562,20 +1590,20 @@ class TenantService: # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in db.session.execute(stmt): + for account, role in session.execute(stmt): account.role = role updated_accounts.append(account) return updated_accounts @staticmethod - def has_roles(tenant: Tenant, roles: list[TenantAccountRole]) -> bool: + def has_roles(tenant: Tenant, roles: list[TenantAccountRole], *, session: scoped_session | Session) -> bool: """Check if user has any of the given roles for a tenant""" if not all(isinstance(role, TenantAccountRole) for role in roles): raise ValueError("all roles must be TenantAccountRole") return ( - db.session.scalar( + session.scalar( select(TenantAccountJoin) .where( TenantAccountJoin.tenant_id == tenant.id, @@ -1587,9 +1615,11 @@ class TenantService: ) @staticmethod - def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None: + def get_user_role( + account: Account, tenant: Tenant, *, session: scoped_session | Session + ) -> TenantAccountRole | None: """Get the role of the current account for a given tenant""" - join = db.session.scalar( + join = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1597,12 +1627,14 @@ class TenantService: return TenantAccountRole(join.role) if join else None @staticmethod - def get_tenant_count() -> int: + def get_tenant_count(*, session: scoped_session | Session) -> int: """Get tenant count""" - return cast(int, db.session.scalar(select(func.count(Tenant.id)))) + return cast(int, session.scalar(select(func.count(Tenant.id)))) @staticmethod - def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str): + def check_member_permission( + tenant: Tenant, operator: Account, member: Account | None, action: str, *, session: scoped_session | Session + ): """Check member permission""" if action not in {"add", "remove", "update"}: raise InvalidActionError("Invalid action.") @@ -1636,7 +1668,7 @@ class TenantService: "update": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], } - ta_operator = db.session.scalar( + ta_operator = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id) .limit(1) @@ -1646,7 +1678,7 @@ class TenantService: raise NoPermissionError(f"No permission to {action} member.") if action == "remove" and ta_operator.role == TenantAccountRole.ADMIN and member: - ta_member = db.session.scalar( + ta_member = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id) .limit(1) @@ -1655,7 +1687,9 @@ class TenantService: raise NoPermissionError(f"No permission to {action} member.") @staticmethod - def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account): + def remove_member_from_tenant( + tenant: Tenant, account: Account, operator: Account, *, session: scoped_session | Session + ): """Remove member from tenant. Apps and datasets maintained by the removed member are reassigned to @@ -1667,9 +1701,9 @@ class TenantService: if operator.id == account.id: raise CannotOperateSelfError("Cannot operate self.") - TenantService.check_member_permission(tenant, operator, account, "remove") + TenantService.check_member_permission(tenant, operator, account, "remove", session=session) - ta = db.session.scalar( + ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1686,7 +1720,7 @@ class TenantService: if dify_config.RBAC_ENABLED: owner_id = AccountService.get_rbac_workspace_owner_account_id(str(tenant.id), str(operator.id)) else: - owner_id = db.session.scalar( + owner_id = session.scalar( select(TenantAccountJoin.account_id) .where( TenantAccountJoin.tenant_id == tenant.id, @@ -1697,7 +1731,7 @@ class TenantService: if owner_id is None: raise ValueError(f"Workspace owner not found for tenant {tenant.id}.") - db.session.execute( + session.execute( update(App) .where( App.tenant_id == tenant.id, @@ -1705,7 +1739,7 @@ class TenantService: ) .values(maintainer=owner_id) ) - db.session.execute( + session.execute( update(Dataset) .where( Dataset.tenant_id == tenant.id, @@ -1713,23 +1747,23 @@ class TenantService: ) .values(maintainer=owner_id) ) - db.session.delete(ta) + session.delete(ta) # Clean up orphaned pending accounts (invited but never activated) should_delete_account = False if account.status == AccountStatus.PENDING: # autoflush flushes ta deletion before this query, so 0 means no remaining joins remaining_joins = ( - db.session.scalar( + session.scalar( select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id) ) or 0 ) if remaining_joins == 0: - db.session.delete(account) + session.delete(account) should_delete_account = True - db.session.commit() + session.commit() if should_delete_account: logger.info( @@ -1755,12 +1789,14 @@ class TenantService: ) @staticmethod - def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account): + def update_member_role( + tenant: Tenant, member: Account, new_role: str, operator: Account, *, session: scoped_session | Session + ): """Update member role""" - TenantService.check_member_permission(tenant, operator, member, "update") + TenantService.check_member_permission(tenant, operator, member, "update", session=session) new_tenant_role = TenantAccountRole(new_role) - target_member_join = db.session.scalar( + target_member_join = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id) .limit(1) @@ -1769,7 +1805,7 @@ class TenantService: if not target_member_join: raise MemberNotInTenantError("Member not in tenant.") - operator_role = TenantService.get_user_role(operator, tenant) + operator_role = TenantService.get_user_role(operator, tenant, session=session) target_role = TenantAccountRole(target_member_join.role) if operator_role == TenantAccountRole.ADMIN and (TenantAccountRole.OWNER in {target_role, new_tenant_role}): raise NoPermissionError("No permission to update member.") @@ -1779,7 +1815,7 @@ class TenantService: if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = db.session.scalar( + current_owner_join = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") .limit(1) @@ -1815,7 +1851,7 @@ class TenantService: ) else: target_member_join.role = new_tenant_role - db.session.commit() + session.commit() @staticmethod def get_custom_config(tenant_id: str): @@ -1824,13 +1860,13 @@ class TenantService: return tenant.custom_config_dict @staticmethod - def is_owner(account: Account, tenant: Tenant) -> bool: - return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER + def is_owner(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool: + return TenantService.get_user_role(account, tenant, session=session) == TenantAccountRole.OWNER @staticmethod - def is_member(account: Account, tenant: Tenant) -> bool: + def is_member(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool: """Check if the account is a member of the tenant""" - return TenantService.get_user_role(account, tenant) is not None + return TenantService.get_user_role(account, tenant, session=session) is not None class RegisterService: @@ -1839,7 +1875,16 @@ class RegisterService: return f"member_invite:token:{token}" @classmethod - def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None): + def setup( + cls, + email: str, + name: str, + password: str, + ip_address: str, + language: str | None, + *, + session: scoped_session | Session, + ): """ Setup dify @@ -1856,22 +1901,23 @@ class RegisterService: interface_language=get_valid_language(language), password=password, is_setup=True, + session=session, ) account.last_login_ip = ip_address account.initialized_at = naive_utc_now() - TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) + TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True, session=session) dify_setup = DifySetup(version=dify_config.project.version) - db.session.add(dify_setup) - db.session.commit() + session.add(dify_setup) + session.commit() except Exception as e: - db.session.execute(delete(DifySetup)) - db.session.execute(delete(TenantAccountJoin)) - db.session.execute(delete(Account)) - db.session.execute(delete(Tenant)) - db.session.commit() + session.execute(delete(DifySetup)) + session.execute(delete(TenantAccountJoin)) + session.execute(delete(Account)) + session.execute(delete(Tenant)) + session.commit() logger.exception("Setup account failed, email: %s, name: %s", email, name) raise ValueError(f"Setup failed: {e}") @@ -1889,9 +1935,11 @@ class RegisterService: is_setup: bool | None = False, create_workspace_required: bool | None = True, timezone: str | None = None, + *, + session: scoped_session | Session, ) -> Account: """Register account""" - db.session.begin_nested() + session.begin_nested() try: interface_language = get_valid_language(language) account = AccountService.create_account( @@ -1901,12 +1949,13 @@ class RegisterService: password=password, is_setup=is_setup, timezone=timezone, + session=session, ) account.status = status or AccountStatus.ACTIVE account.initialized_at = naive_utc_now() if open_id is not None and provider is not None: - AccountService.link_account_integrate(provider, open_id, account) + AccountService.link_account_integrate(provider, open_id, account, session=session) if ( FeatureService.get_system_features().is_allow_create_workspace @@ -1914,27 +1963,27 @@ class RegisterService: and FeatureService.get_system_features().license.workspaces.is_available() ): try: - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, db.session, role="owner") + tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=session) + TenantService.create_tenant_member(tenant, account, session, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) except Exception: _try_join_enterprise_default_workspace(str(account.id)) raise - db.session.commit() + session.commit() _try_join_enterprise_default_workspace(str(account.id)) except WorkSpaceNotAllowedCreateError: - db.session.rollback() + session.rollback() logger.exception("Register failed") raise AccountRegisterError("Workspace is not allowed to create.") except AccountRegisterError as are: - db.session.rollback() + session.rollback() logger.exception("Register failed") raise are except Exception as e: - db.session.rollback() + session.rollback() logger.exception("Register failed") raise AccountRegisterError(f"Registration failed: {e}") from e @@ -1942,7 +1991,14 @@ class RegisterService: @classmethod def invite_new_member( - cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None + cls, + tenant: Tenant, + email: str, + language: str | None, + role: str = "normal", + inviter: Account | None = None, + *, + session: scoped_session | Session, ) -> str: if not inviter: raise ValueError("Inviter is required") @@ -1960,7 +2016,7 @@ class RegisterService: requires_setup = False if not account: - TenantService.check_member_permission(tenant, inviter, None, "add") + TenantService.check_member_permission(tenant, inviter, None, "add", session=session) name = normalized_email.split("@")[0] account = cls.register( @@ -1969,13 +2025,14 @@ class RegisterService: language=language, status=AccountStatus.PENDING, is_setup=True, + session=session, ) - TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role) - TenantService.switch_tenant(account, tenant.id) + TenantService.create_tenant_member(tenant, account, session, tenant_join_role) + TenantService.switch_tenant(account, tenant.id, session=session) requires_setup = True else: - TenantService.check_member_permission(tenant, inviter, account, "add") - ta = db.session.scalar( + TenantService.check_member_permission(tenant, inviter, account, "add", session=session) + ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1983,7 +2040,7 @@ class RegisterService: requires_setup = account.status == AccountStatus.PENDING if not ta and (account.status == AccountStatus.PENDING or dify_config.RBAC_ENABLED): - TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role) + TenantService.create_tenant_member(tenant, account, session, tenant_join_role) # Support resend invitation email when the account is pending status if account.status != AccountStatus.PENDING: @@ -2052,20 +2109,20 @@ class RegisterService: @classmethod def get_invitation_if_token_valid( - cls, workspace_id: str | None, email: str | None, token: str + cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session ) -> InvitationDetailDict | None: invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None - tenant = db.session.scalar( + tenant = session.scalar( select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1) ) if not tenant: return None - account = db.session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1)) + account = session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1)) if not account: return None @@ -2105,13 +2162,13 @@ class RegisterService: @classmethod def get_invitation_with_case_fallback( - cls, workspace_id: str | None, email: str | None, token: str + cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session ) -> InvitationDetailDict | None: - invitation = cls.get_invitation_if_token_valid(workspace_id, email, token) + invitation = cls.get_invitation_if_token_valid(workspace_id, email, token, session=session) if invitation or not email or email == email.lower(): return invitation normalized_email = email.lower() - return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token) + return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token, session=session) def _generate_refresh_token(length: int = 64): diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index e1762c686ff..4f69c4b44a9 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -4,6 +4,7 @@ from typing import TypedDict import pandas as pd from sqlalchemy import delete, or_, select, update +from sqlalchemy.orm import scoped_session from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -300,17 +301,19 @@ class AppAnnotationService: return annotation @classmethod - def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str): + def update_app_annotation_directly( + cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str, session: scoped_session + ): # get app info _, current_tenant_id = current_account_with_tenant() - app = db.session.scalar( + app = session.scalar( select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.get(MessageAnnotation, annotation_id) + annotation = session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") @@ -326,9 +329,9 @@ class AppAnnotationService: annotation.content = answer annotation.question = question - db.session.commit() + session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = db.session.scalar( + app_annotation_setting = session.scalar( select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) @@ -344,33 +347,33 @@ class AppAnnotationService: return annotation @classmethod - def delete_app_annotation(cls, app_id: str, annotation_id: str): + def delete_app_annotation(cls, app_id: str, annotation_id: str, session: scoped_session): # get app info _, current_tenant_id = current_account_with_tenant() - app = db.session.scalar( + app = session.scalar( select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.get(MessageAnnotation, annotation_id) + annotation = session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") - db.session.delete(annotation) + session.delete(annotation) - annotation_hit_histories = db.session.scalars( + annotation_hit_histories = session.scalars( select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id) ).all() if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: - db.session.delete(annotation_hit_history) + session.delete(annotation_hit_history) - db.session.commit() + session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = db.session.scalar( + app_annotation_setting = session.scalar( select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) diff --git a/api/services/oauth_server.py b/api/services/oauth_server.py index 22648070f01..5f3277c9525 100644 --- a/api/services/oauth_server.py +++ b/api/services/oauth_server.py @@ -91,4 +91,4 @@ class OAuthServerService: user_id_str = user_account_id.decode("utf-8") - return AccountService.load_user(user_id_str) + return AccountService.load_user(user_id_str, db.session) diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 70114a83f0b..180c077b88a 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -36,7 +36,9 @@ class WorkspaceService: feature = FeatureService.get_features(tenant.id, exclude_vector_space=True) can_replace_logo = feature.can_replace_logo - if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]): + if can_replace_logo and TenantService.has_roles( + tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], session=db.session + ): base_url = dify_config.FILES_URL replace_webapp_logo = ( f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 70988eb0a13..ea875e63fe8 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -84,6 +84,7 @@ def setup_account(request) -> Generator[Account, None, None]: password=secrets.token_hex(16), ip_address="localhost", language="en-US", + session=db.session, ) with _CACHED_APP.test_request_context(): diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index d87afb87669..464e0134a2f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -548,6 +548,7 @@ class TestAccountGeneration: provider="github", language="en-US", timezone=None, + session=ANY, ) else: mock_register_service.register.assert_not_called() @@ -581,6 +582,7 @@ class TestAccountGeneration: provider="github", language="en-US", timezone=None, + session=ANY, ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) @@ -612,6 +614,7 @@ class TestAccountGeneration: provider="github", language="zh-Hans", timezone="Asia/Shanghai", + session=ANY, ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) @@ -643,6 +646,7 @@ class TestAccountGeneration: provider="github", language="zh-Hans", timezone=None, + session=ANY, ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email") @@ -673,7 +677,7 @@ class TestAccountGeneration: assert result == mock_account assert oauth_new_user is False - mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace") + mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace", session=ANY) mock_tenant_service.create_tenant_member.assert_called_once_with( mock_new_tenant, mock_account, ANY, role="owner" ) diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py b/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py index 5fe0f787524..00f605b74c5 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py @@ -1,7 +1,7 @@ from __future__ import annotations import uuid -from collections.abc import Callable, Iterator +from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Literal from unittest.mock import patch @@ -12,7 +12,6 @@ from flask import Flask from sqlalchemy.orm import Session from controllers.openapi.auth.data import AuthData -from extensions.ext_database import db from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx from models import Account, Tenant from services.account_service import AccountService, TenantService @@ -46,20 +45,25 @@ def make_account(db_session_with_containers: Session) -> Callable[..., Account]: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) if with_owner_tenant: - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist( + account, name=fake.company(), session=db_session_with_containers + ) return account return _make -def add_tenant_for_account(account: Account, *, role: str = "normal", name: str = "Second WS") -> Tenant: +def add_tenant_for_account( + account: Account, *, session: Session, role: str = "normal", name: str = "Second WS" +) -> Tenant: """Create an additional tenant and join ``account`` to it (real service calls).""" with patch("services.account_service.FeatureService") as mock_feature_service: mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True - tenant = TenantService.create_tenant(name=name) - TenantService.create_tenant_member(tenant, account, db.session, role=role) + tenant = TenantService.create_tenant(name=name, session=session) + TenantService.create_tenant_member(tenant, account, session, role=role) return tenant @@ -93,7 +97,7 @@ def account_auth_context( *, token_id: uuid.UUID, client_id: str = "integration-cli", -) -> Iterator[AuthContext]: +) -> Generator[AuthContext]: """Publish an account ``AuthContext`` for handlers that read ``get_auth_ctx()``. The auth pipeline normally sets this ContextVar; the integration suite diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/test_account.py b/api/tests/test_containers_integration_tests/controllers/openapi/test_account.py index 77c812c0b34..7b5bef7b613 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/test_account.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/test_account.py @@ -4,6 +4,7 @@ from collections.abc import Callable from inspect import unwrap from flask import Flask +from sqlalchemy.orm import Session from controllers.openapi.account import AccountApi from models import Account @@ -34,11 +35,13 @@ class TestAccountInfo: # the only workspace the account belongs to. assert result.default_workspace_id == owner_tenant.id - def test_lists_all_joined_workspaces(self, app: Flask, make_account: Callable[..., Account]) -> None: + def test_lists_all_joined_workspaces( + self, app: Flask, db_session_with_containers: Session, make_account: Callable[..., Account] + ) -> None: account = make_account() owner_tenant = account.current_tenant assert owner_tenant is not None - second = add_tenant_for_account(account, role="normal", name="Second WS") + second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS") api = AccountApi() with app.test_request_context("/openapi/v1/account"): diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/test_app_dsl.py b/api/tests/test_containers_integration_tests/controllers/openapi/test_app_dsl.py index 12018c3c67c..93e8927cfef 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/test_app_dsl.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/test_app_dsl.py @@ -81,8 +81,9 @@ def _app_and_account(db_session: Session, *, mode: str = "chat") -> tuple[App, A name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session) tenant = account.current_tenant assert tenant is not None app_args = CreateAppParams( diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/test_workspaces.py b/api/tests/test_containers_integration_tests/controllers/openapi/test_workspaces.py index aed8c415454..18075704325 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/test_workspaces.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/test_workspaces.py @@ -41,7 +41,7 @@ class TestWorkspacesList: account = make_account() owner_tenant = account.current_tenant assert owner_tenant is not None - second = add_tenant_for_account(account, role="normal", name="Second WS") + second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS") api = WorkspacesApi() with app.test_request_context("/openapi/v1/workspaces"): @@ -90,7 +90,9 @@ class TestWorkspaceSwitch: account = make_account() owner_tenant = account.current_tenant assert owner_tenant is not None - target = add_tenant_for_account(account, role="normal", name="Switch Target") + target = add_tenant_for_account( + account, session=db_session_with_containers, role="normal", name="Switch Target" + ) api = WorkspaceSwitchApi() with app.test_request_context(f"/openapi/v1/workspaces/{target.id}/switch", method="POST"): diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 9da6b04a2c4..c0da09278e3 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -27,8 +27,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -88,8 +89,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -141,8 +143,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -194,8 +197,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -257,8 +261,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -291,8 +296,11 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + account1, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company()) tenant1 = account1.current_tenant account2 = AccountService.create_account( @@ -300,8 +308,11 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + account2, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company()) tenant2 = account2.current_tenant # Create dataset for tenant1 @@ -367,8 +378,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Don't create any datasets @@ -391,8 +403,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create multiple datasets @@ -452,8 +465,9 @@ class TestKnowledgeRetrievalIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -520,8 +534,9 @@ class TestKnowledgeRetrievalIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset but no documents @@ -568,8 +583,9 @@ class TestKnowledgeRetrievalIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index a2f5370cb76..65a5b0a96bf 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -64,12 +64,13 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) assert account.email == email assert account.status == AccountStatus.ACTIVE # Login with correct password - logged_in = AccountService.authenticate(email, password) + logged_in = AccountService.authenticate(email, password, session=db_session_with_containers) assert logged_in.id == account.id def test_create_account_without_password( @@ -90,6 +91,7 @@ class TestAccountService: name=name, interface_language="en-US", password=None, + session=db_session_with_containers, ) assert account.email == email assert account.password is None @@ -115,6 +117,7 @@ class TestAccountService: name=name, interface_language="en-US", password="invalid_new_password", + session=db_session_with_containers, ) def test_create_account_registration_disabled( @@ -135,6 +138,7 @@ class TestAccountService: name=name, interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) def test_create_account_email_in_freeze( @@ -158,6 +162,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) dify_config.BILLING_ENABLED = False # Reset config for other tests @@ -172,7 +177,7 @@ class TestAccountService: email = fake.email() password = generate_valid_password(fake) with pytest.raises(AccountPasswordError): - AccountService.authenticate(email, password) + AccountService.authenticate(email, password, session=db_session_with_containers) def test_authenticate_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -192,6 +197,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Ban the account @@ -200,7 +206,7 @@ class TestAccountService: db_session_with_containers.commit() with pytest.raises(AccountLoginError): - AccountService.authenticate(email, password) + AccountService.authenticate(email, password, session=db_session_with_containers) def test_authenticate_wrong_password(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -221,10 +227,11 @@ class TestAccountService: name=name, interface_language="en-US", password=correct_password, + session=db_session_with_containers, ) with pytest.raises(AccountPasswordError): - AccountService.authenticate(email, wrong_password) + AccountService.authenticate(email, wrong_password, session=db_session_with_containers) def test_authenticate_with_invite_token( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -246,6 +253,7 @@ class TestAccountService: name=name, interface_language="en-US", password=None, + session=db_session_with_containers, ) # Authenticate with invite token to set password @@ -253,6 +261,7 @@ class TestAccountService: email, new_password, invite_token="valid_invite_token", + session=db_session_with_containers, ) assert authenticated_account.id == account.id @@ -279,13 +288,14 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) account.status = AccountStatus.PENDING db_session_with_containers.commit() # Authenticate should activate the account - authenticated_account = AccountService.authenticate(email, password) + authenticated_account = AccountService.authenticate(email, password, session=db_session_with_containers) assert authenticated_account.status == AccountStatus.ACTIVE assert authenticated_account.initialized_at is not None @@ -310,13 +320,16 @@ class TestAccountService: name=name, interface_language="en-US", password=old_password, + session=db_session_with_containers, ) # Update password - updated_account = AccountService.update_account_password(account, old_password, new_password) + updated_account = AccountService.update_account_password( + account, old_password, new_password, session=db_session_with_containers + ) # Verify new password works - authenticated_account = AccountService.authenticate(email, new_password) + authenticated_account = AccountService.authenticate(email, new_password, session=db_session_with_containers) assert authenticated_account.id == account.id def test_update_account_password_wrong_current_password( @@ -341,10 +354,13 @@ class TestAccountService: name=name, interface_language="en-US", password=old_password, + session=db_session_with_containers, ) with pytest.raises(CurrentPasswordIncorrectError): - AccountService.update_account_password(account, wrong_password, new_password) + AccountService.update_account_password( + account, wrong_password, new_password, session=db_session_with_containers + ) def test_update_account_password_invalid_new_password( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -366,11 +382,12 @@ class TestAccountService: name=name, interface_language="en-US", password=old_password, + session=db_session_with_containers, ) # Test with too short password (assuming minimum length validation) with pytest.raises(ValueError): # Password validation error - AccountService.update_account_password(account, old_password, "123") + AccountService.update_account_password(account, old_password, "123", session=db_session_with_containers) def test_create_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -394,6 +411,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) assert account.email == email @@ -427,6 +445,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) def test_create_account_and_tenant_workspace_limit_exceeded( @@ -455,6 +474,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) def test_link_account_integrate_new_provider( @@ -476,10 +496,13 @@ class TestAccountService: name=name, interface_language="en-US", password=None, + session=db_session_with_containers, ) # Link with new provider - AccountService.link_account_integrate("new-google", "google_open_id_123", account) + AccountService.link_account_integrate( + "new-google", "google_open_id_123", account, session=db_session_with_containers + ) # Verify integration was created from models import AccountIntegrate @@ -511,13 +534,18 @@ class TestAccountService: name=name, interface_language="en-US", password=None, + session=db_session_with_containers, ) # Link with provider first time - AccountService.link_account_integrate("exists-google", "google_open_id_123", account) + AccountService.link_account_integrate( + "exists-google", "google_open_id_123", account, session=db_session_with_containers + ) # Link with same provider but different open_id (should update) - AccountService.link_account_integrate("exists-google", "google_open_id_456", account) + AccountService.link_account_integrate( + "exists-google", "google_open_id_456", account, session=db_session_with_containers + ) # Verify integration was updated from models import AccountIntegrate @@ -547,10 +575,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Close account - AccountService.close_account(account) + AccountService.close_account(account, session=db_session_with_containers) # Verify account status changed @@ -576,10 +605,13 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Update account fields - updated_account = AccountService.update_account(account, name=updated_name, interface_theme="dark") + updated_account = AccountService.update_account( + account, name=updated_name, interface_theme="dark", session=db_session_with_containers + ) assert updated_account.name == updated_name assert updated_account.interface_theme == "dark" @@ -604,10 +636,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) with pytest.raises(AttributeError): - AccountService.update_account(account, invalid_field="value") + AccountService.update_account(account, invalid_field="value", session=db_session_with_containers) def test_update_login_info(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -628,10 +661,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Update login info - AccountService.update_login_info(account, ip_address=ip_address) + AccountService.update_login_info(account, db_session_with_containers, ip_address=ip_address) # Verify login info was updated @@ -659,10 +693,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Login - token_pair = AccountService.login(account, ip_address=ip_address) + token_pair = AccountService.login(account, ip_address=ip_address, session=db_session_with_containers) assert isinstance(token_pair, TokenPair) assert token_pair.access_token == "mock_access_token" @@ -697,13 +732,14 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) account.status = AccountStatus.PENDING db_session_with_containers.commit() # Login should activate the account - token_pair = AccountService.login(account) + token_pair = AccountService.login(account, session=db_session_with_containers) db_session_with_containers.refresh(account) assert account.status == AccountStatus.ACTIVE @@ -727,10 +763,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Login first to get refresh token - token_pair = AccountService.login(account) + token_pair = AccountService.login(account, session=db_session_with_containers) # Logout AccountService.logout(account=account) @@ -761,15 +798,20 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create associated Tenant - TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + TenantService.create_owner_tenant_if_not_exist( + account=account, name=tenant_name, is_setup=True, session=db_session_with_containers + ) # Login to get initial tokens - initial_token_pair = AccountService.login(account) + initial_token_pair = AccountService.login(account, session=db_session_with_containers) # Refresh token - new_token_pair = AccountService.refresh_token(initial_token_pair.refresh_token) + new_token_pair = AccountService.refresh_token( + initial_token_pair.refresh_token, session=db_session_with_containers + ) assert isinstance(new_token_pair, TokenPair) assert new_token_pair.access_token == "new_mock_access_token" @@ -782,7 +824,7 @@ class TestAccountService: fake = Faker() invalid_token = fake.uuid4() with pytest.raises(ValueError, match="Invalid refresh token"): - AccountService.refresh_token(invalid_token) + AccountService.refresh_token(invalid_token, session=db_session_with_containers) def test_refresh_token_invalid_account( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -805,10 +847,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Login to get tokens - token_pair = AccountService.login(account) + token_pair = AccountService.login(account, session=db_session_with_containers) # Delete account @@ -817,7 +860,7 @@ class TestAccountService: # Try to refresh token with deleted account with pytest.raises(ValueError, match="Invalid account"): - AccountService.refresh_token(token_pair.refresh_token) + AccountService.refresh_token(token_pair.refresh_token, session=db_session_with_containers) def test_load_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -838,12 +881,15 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create associated Tenant - TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + TenantService.create_owner_tenant_if_not_exist( + account=account, name=tenant_name, is_setup=True, session=db_session_with_containers + ) # Load user - loaded_user = AccountService.load_user(account.id) + loaded_user = AccountService.load_user(account.id, db_session_with_containers) assert loaded_user is not None assert loaded_user.id == account.id @@ -855,7 +901,7 @@ class TestAccountService: """ fake = Faker() non_existent_user_id = fake.uuid4() - loaded_user = AccountService.load_user(non_existent_user_id) + loaded_user = AccountService.load_user(non_existent_user_id, db_session_with_containers) assert loaded_user is None def test_load_user_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): @@ -876,6 +922,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Ban the account @@ -884,7 +931,7 @@ class TestAccountService: db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception - AccountService.load_user(account.id) + AccountService.load_user(account.id, db_session_with_containers) def test_get_account_jwt_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -905,6 +952,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate JWT token @@ -939,12 +987,17 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create associated Tenant - TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + TenantService.create_owner_tenant_if_not_exist( + account=account, name=tenant_name, is_setup=True, session=db_session_with_containers + ) # Load logged in account - loaded_account = AccountService.load_logged_in_account(account_id=account.id) + loaded_account = AccountService.load_logged_in_account( + account_id=account.id, session=db_session_with_containers + ) assert loaded_account is not None assert loaded_account.id == account.id @@ -969,10 +1022,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Get user through email - found_user = AccountService.get_user_through_email(email) + found_user = AccountService.get_user_through_email(email, session=db_session_with_containers) assert found_user is not None assert found_user.id == account.id @@ -986,7 +1040,7 @@ class TestAccountService: fake = Faker() domain = f"test-{fake.random_letters(10)}.com" non_existent_email = fake.email(domain=domain) - found_user = AccountService.get_user_through_email(non_existent_email) + found_user = AccountService.get_user_through_email(non_existent_email, session=db_session_with_containers) assert found_user is None def test_get_user_through_email_banned_account( @@ -1009,6 +1063,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Ban the account @@ -1017,7 +1072,7 @@ class TestAccountService: db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception - AccountService.get_user_through_email(email) + AccountService.get_user_through_email(email, session=db_session_with_containers) def test_get_user_through_email_in_freeze( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1032,7 +1087,7 @@ class TestAccountService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True with pytest.raises(AccountRegisterError): - AccountService.get_user_through_email(email_in_freeze) + AccountService.get_user_through_email(email_in_freeze, session=db_session_with_containers) # Reset config dify_config.BILLING_ENABLED = False @@ -1055,6 +1110,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) with ( @@ -1092,6 +1148,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate verification code @@ -1122,6 +1179,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate verification code @@ -1152,6 +1210,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate verification code @@ -1206,7 +1265,7 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) assert tenant.name == tenant_name assert tenant.plan == "basic" @@ -1227,7 +1286,7 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = False with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception - TenantService.create_tenant(name=tenant_name) + TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) def test_create_tenant_with_custom_name( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1243,7 +1302,9 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = False # Create tenant with setup flag (should bypass workspace creation restriction) - tenant = TenantService.create_tenant(name=custom_tenant_name, is_setup=True, is_from_dashboard=True) + tenant = TenantService.create_tenant( + name=custom_tenant_name, is_setup=True, is_from_dashboard=True, session=db_session_with_containers + ) assert tenant.name == custom_tenant_name assert tenant.plan == "basic" @@ -1267,12 +1328,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create tenant member @@ -1302,18 +1364,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account1 = AccountService.create_account( email=email1, name=name1, interface_language="en-US", password=password1, + session=db_session_with_containers, ) account2 = AccountService.create_account( email=email2, name=name2, interface_language="en-US", password=password2, + session=db_session_with_containers, ) # Create first owner @@ -1340,12 +1404,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create member with initial role @@ -1379,16 +1444,17 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - tenant1 = TenantService.create_tenant(name=tenant1_name) - tenant2 = TenantService.create_tenant(name=tenant2_name) + tenant1 = TenantService.create_tenant(name=tenant1_name, session=db_session_with_containers) + tenant2 = TenantService.create_tenant(name=tenant2_name, session=db_session_with_containers) # Add account to both tenants TenantService.create_tenant_member(tenant1, account, db_session_with_containers, role="normal") TenantService.create_tenant_member(tenant2, account, db_session_with_containers, role="admin") # Get join tenants - join_tenants = TenantService.get_join_tenants(account) + join_tenants = TenantService.get_join_tenants(account, session=db_session_with_containers) assert len(join_tenants) == 2 tenant_names = [tenant.name for tenant in join_tenants] @@ -1417,8 +1483,9 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Add account to tenant and set as current TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="owner") @@ -1427,7 +1494,7 @@ class TestTenantService: db_session_with_containers.commit() # Get current tenant - current_tenant = TenantService.get_current_tenant_by_account(account) + current_tenant = TenantService.get_current_tenant_by_account(account, session=db_session_with_containers) assert current_tenant.id == tenant.id assert current_tenant.name == tenant.name @@ -1454,11 +1521,12 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Try to get current tenant (should fail) with pytest.raises((AttributeError, TenantNotFoundError)): - TenantService.get_current_tenant_by_account(account) + TenantService.get_current_tenant_by_account(account, session=db_session_with_containers) def test_switch_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1481,9 +1549,10 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - tenant1 = TenantService.create_tenant(name=tenant1_name) - tenant2 = TenantService.create_tenant(name=tenant2_name) + tenant1 = TenantService.create_tenant(name=tenant1_name, session=db_session_with_containers) + tenant2 = TenantService.create_tenant(name=tenant2_name, session=db_session_with_containers) # Add account to both tenants TenantService.create_tenant_member(tenant1, account, db_session_with_containers, role="owner") @@ -1495,7 +1564,7 @@ class TestTenantService: db_session_with_containers.commit() # Switch to second tenant - TenantService.switch_tenant(account, tenant2.id) + TenantService.switch_tenant(account, tenant2.id, session=db_session_with_containers) # Verify tenant was switched db_session_with_containers.refresh(account) @@ -1520,11 +1589,12 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Try to switch tenant without providing tenant ID with pytest.raises(ValueError, match="Tenant ID must be provided"): - TenantService.switch_tenant(account, None) + TenantService.switch_tenant(account, None, session=db_session_with_containers) def test_switch_tenant_account_not_member( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1548,12 +1618,13 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Try to switch to tenant where account is not a member with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): - TenantService.switch_tenant(account, tenant.id) + TenantService.switch_tenant(account, tenant.id, session=db_session_with_containers) def test_has_roles_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1573,18 +1644,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) admin_account = AccountService.create_account( email=admin_email, name=admin_name, interface_language="en-US", password=admin_password, + session=db_session_with_containers, ) # Add members with different roles @@ -1594,15 +1667,15 @@ class TestTenantService: # Check if tenant has owner role from models.account import TenantAccountRole - has_owner = TenantService.has_roles(tenant, [TenantAccountRole.OWNER]) + has_owner = TenantService.has_roles(tenant, [TenantAccountRole.OWNER], session=db_session_with_containers) assert has_owner is True # Check if tenant has admin role - has_admin = TenantService.has_roles(tenant, [TenantAccountRole.ADMIN]) + has_admin = TenantService.has_roles(tenant, [TenantAccountRole.ADMIN], session=db_session_with_containers) assert has_admin is True # Check if tenant has normal role (should be False) - has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) + has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL], session=db_session_with_containers) assert has_normal is False def test_has_roles_invalid_role_type(self, db_session_with_containers: Session, mock_external_service_dependencies): @@ -1618,11 +1691,11 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Try to check roles with invalid role type with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): - TenantService.has_roles(tenant, [invalid_role]) + TenantService.has_roles(tenant, [invalid_role], session=db_session_with_containers) def test_get_user_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1639,19 +1712,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Add account to tenant with specific role TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="editor") # Get user role - user_role = TenantService.get_user_role(account, tenant) + user_role = TenantService.get_user_role(account, tenant, session=db_session_with_containers) assert user_role == "editor" @@ -1675,18 +1749,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -1694,7 +1770,9 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="normal") # Check owner permission to add member (should succeed) - TenantService.check_member_permission(tenant, owner_account, member_account, "add") + TenantService.check_member_permission( + tenant, owner_account, member_account, "add", session=db_session_with_containers + ) def test_check_member_permission_invalid_action( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1714,12 +1792,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Add account to tenant @@ -1727,7 +1806,9 @@ class TestTenantService: # Try to check permission with invalid action with pytest.raises(Exception, match="Invalid action"): - TenantService.check_member_permission(tenant, account, None, invalid_action) + TenantService.check_member_permission( + tenant, account, None, invalid_action, session=db_session_with_containers + ) def test_check_member_permission_operate_self( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1746,12 +1827,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Add account to tenant @@ -1759,7 +1841,9 @@ class TestTenantService: # Try to check permission to operate self with pytest.raises(Exception, match="Cannot operate self"): - TenantService.check_member_permission(tenant, account, account, "remove") + TenantService.check_member_permission( + tenant, account, account, "remove", session=db_session_with_containers + ) def test_remove_member_from_tenant_success( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1781,18 +1865,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -1827,7 +1913,9 @@ class TestTenantService: ): mock_sync.return_value = True - TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + TenantService.remove_member_from_tenant( + tenant, member_account, owner_account, session=db_session_with_containers + ) # Verify sync was called mock_sync.assert_called_once_with( @@ -1867,12 +1955,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Add account to tenant @@ -1880,7 +1969,7 @@ class TestTenantService: # Try to remove self with pytest.raises(Exception, match="Cannot operate self"): - TenantService.remove_member_from_tenant(tenant, account, account) + TenantService.remove_member_from_tenant(tenant, account, account, session=db_session_with_containers) def test_remove_member_from_tenant_not_member( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1902,18 +1991,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) non_member_account = AccountService.create_account( email=non_member_email, name=non_member_name, interface_language="en-US", password=non_member_password, + session=db_session_with_containers, ) # Add only owner to tenant @@ -1921,7 +2012,9 @@ class TestTenantService: # Try to remove non-member with pytest.raises(Exception, match="Member not in tenant"): - TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) + TenantService.remove_member_from_tenant( + tenant, non_member_account, owner_account, session=db_session_with_containers + ) def test_update_member_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1941,18 +2034,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -1960,7 +2055,9 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="normal") # Update member role - TenantService.update_member_role(tenant, member_account, "admin", owner_account) + TenantService.update_member_role( + tenant, member_account, "admin", owner_account, session=db_session_with_containers + ) # Verify role was updated from models.account import TenantAccountJoin @@ -1990,18 +2087,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -2009,7 +2108,9 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="admin") # Update member role to owner - TenantService.update_member_role(tenant, member_account, "owner", owner_account) + TenantService.update_member_role( + tenant, member_account, "owner", owner_account, session=db_session_with_containers + ) # Verify roles were updated correctly from models.account import TenantAccountJoin @@ -2047,18 +2148,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -2067,7 +2170,9 @@ class TestTenantService: # Try to update member role to already assigned role with pytest.raises(Exception, match="The provided role is already assigned to the member"): - TenantService.update_member_role(tenant, member_account, "admin", owner_account) + TenantService.update_member_role( + tenant, member_account, "admin", owner_account, session=db_session_with_containers + ) def test_get_tenant_count_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -2083,12 +2188,12 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create multiple tenants - tenant1 = TenantService.create_tenant(name=tenant1_name) - tenant2 = TenantService.create_tenant(name=tenant2_name) - tenant3 = TenantService.create_tenant(name=tenant3_name) + tenant1 = TenantService.create_tenant(name=tenant1_name, session=db_session_with_containers) + tenant2 = TenantService.create_tenant(name=tenant2_name, session=db_session_with_containers) + tenant3 = TenantService.create_tenant(name=tenant3_name, session=db_session_with_containers) # Get tenant count - tenant_count = TenantService.get_tenant_count() + tenant_count = TenantService.get_tenant_count(session=db_session_with_containers) # Should have at least 3 tenants (may be more from other tests) assert tenant_count >= 3 @@ -2118,10 +2223,11 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create owner tenant - TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name, session=db_session_with_containers) # Verify tenant was created and linked from models.account import TenantAccountJoin @@ -2158,15 +2264,18 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - existing_tenant = TenantService.create_tenant(name=existing_tenant_name) + existing_tenant = TenantService.create_tenant(name=existing_tenant_name, session=db_session_with_containers) TenantService.create_tenant_member(existing_tenant, account, db_session_with_containers, role="owner") account.current_tenant = existing_tenant db_session_with_containers.commit() # Try to create owner tenant again (should not create new one) - TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) + TenantService.create_owner_tenant_if_not_exist( + account, name=new_workspace_name, session=db_session_with_containers + ) # Verify no new tenant was created tenant_joins = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).all() @@ -2195,11 +2304,14 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Try to create owner tenant (should fail) with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception - TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + TenantService.create_owner_tenant_if_not_exist( + account, name=workspace_name, session=db_session_with_containers + ) def test_get_tenant_members_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -2222,24 +2334,27 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) admin_account = AccountService.create_account( email=admin_email, name=admin_name, interface_language="en-US", password=admin_password, + session=db_session_with_containers, ) normal_account = AccountService.create_account( email=normal_email, name=normal_name, interface_language="en-US", password=normal_password, + session=db_session_with_containers, ) # Add members with different roles @@ -2248,7 +2363,7 @@ class TestTenantService: TenantService.create_tenant_member(tenant, normal_account, db_session_with_containers, role="normal") # Get tenant members - members = TenantService.get_tenant_members(tenant) + members = TenantService.get_tenant_members(tenant, session=db_session_with_containers) assert len(members) == 3 member_emails = [member.email for member in members] @@ -2288,24 +2403,27 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) dataset_operator_account = AccountService.create_account( email=operator_email, name=operator_name, interface_language="en-US", password=operator_password, + session=db_session_with_containers, ) normal_account = AccountService.create_account( email=normal_email, name=normal_name, interface_language="en-US", password=normal_password, + session=db_session_with_containers, ) # Add members with different roles @@ -2316,7 +2434,7 @@ class TestTenantService: TenantService.create_tenant_member(tenant, normal_account, db_session_with_containers, role="normal") # Get dataset operator members - dataset_operators = TenantService.get_dataset_operator_members(tenant) + dataset_operators = TenantService.get_dataset_operator_members(tenant, session=db_session_with_containers) assert len(dataset_operators) == 1 assert dataset_operators[0].email == operator_email @@ -2336,7 +2454,7 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant with custom config - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Set custom config custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} @@ -2402,6 +2520,7 @@ class TestRegisterService: password=admin_password, ip_address=ip_address, language="en-US", + session=db_session_with_containers, ) # Verify account was created @@ -2450,6 +2569,7 @@ class TestRegisterService: password=admin_password, ip_address=ip_address, language="en-US", + session=db_session_with_containers, ) # Verify no entities were created (rollback worked) @@ -2491,6 +2611,7 @@ class TestRegisterService: name=name, password=password, language=language, + session=db_session_with_containers, ) # Verify account was created @@ -2536,6 +2657,7 @@ class TestRegisterService: open_id=open_id, provider=provider, language=language, + session=db_session_with_containers, ) # Verify account was created @@ -2585,6 +2707,7 @@ class TestRegisterService: password=password, language=language, status=AccountStatus.PENDING, + session=db_session_with_containers, ) # Verify account was created with pending status @@ -2624,6 +2747,7 @@ class TestRegisterService: name=name, password=password, language=language, + session=db_session_with_containers, ) # Verify account was created with no tenant @@ -2665,6 +2789,7 @@ class TestRegisterService: name=name, password=password, language=language, + session=db_session_with_containers, ) # Verify account was created with no tenant @@ -2699,6 +2824,7 @@ class TestRegisterService: password=password, language=language, create_workspace_required=False, + session=db_session_with_containers, ) # Verify account was created @@ -2737,12 +2863,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and inviter account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) inviter = AccountService.create_account( email=inviter_email, name=inviter_name, interface_language="en-US", password=inviter_password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") @@ -2757,6 +2884,7 @@ class TestRegisterService: language=language, role="normal", inviter=inviter, + session=db_session_with_containers, ) # Verify token was generated @@ -2803,12 +2931,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and inviter account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) inviter = AccountService.create_account( email=inviter_email, name=inviter_name, interface_language="en-US", password=inviter_password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") @@ -2818,6 +2947,7 @@ class TestRegisterService: name=existing_member_name, interface_language="en-US", password=existing_member_password, + session=db_session_with_containers, ) # Mock the email task @@ -2830,6 +2960,7 @@ class TestRegisterService: language=language, role="admin", inviter=inviter, + session=db_session_with_containers, ) assert token is not None @@ -2846,7 +2977,9 @@ class TestRegisterService: ) assert tenant_join is None - invitation = RegisterService.get_invitation_if_token_valid(None, None, token) + invitation = RegisterService.get_invitation_if_token_valid( + None, None, token, session=db_session_with_containers + ) assert invitation is not None assert invitation["account"].id == existing_account.id assert invitation["data"]["role"] == "admin" @@ -2872,12 +3005,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and inviter account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) inviter = AccountService.create_account( email=inviter_email, name=inviter_name, interface_language="en-US", password=inviter_password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") @@ -2887,6 +3021,7 @@ class TestRegisterService: name=existing_pending_member_name, interface_language="en-US", password=existing_pending_member_password, + session=db_session_with_containers, ) existing_account.status = AccountStatus.PENDING @@ -2906,6 +3041,7 @@ class TestRegisterService: language=language, role="normal", inviter=inviter, + session=db_session_with_containers, ) # Verify token was generated @@ -2930,7 +3066,7 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Execute invitation without inviter (should fail) with pytest.raises(ValueError, match="Inviter is required"): @@ -2940,6 +3076,7 @@ class TestRegisterService: language=language, role="normal", inviter=None, + session=db_session_with_containers, ) def test_invite_new_member_account_already_in_tenant( @@ -2962,12 +3099,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and inviter account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) inviter = AccountService.create_account( email=inviter_email, name=inviter_name, interface_language="en-US", password=inviter_password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") @@ -2977,6 +3115,7 @@ class TestRegisterService: name=already_in_tenant_name, interface_language="en-US", password=already_in_tenant_password, + session=db_session_with_containers, ) existing_account.status = AccountStatus.ACTIVE @@ -2993,6 +3132,7 @@ class TestRegisterService: language=language, role="normal", inviter=inviter, + session=db_session_with_containers, ) def test_generate_invite_token_success( @@ -3011,12 +3151,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Execute token generation @@ -3055,12 +3196,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate a real token @@ -3102,12 +3244,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate a real token @@ -3145,12 +3288,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate a real token @@ -3188,12 +3332,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") @@ -3211,6 +3356,7 @@ class TestRegisterService: workspace_id=tenant.id, email=account.email, token=token, + session=db_session_with_containers, ) # Verify result contains expected data @@ -3236,6 +3382,7 @@ class TestRegisterService: workspace_id=workspace_id, email=email, token=invalid_token, + session=db_session_with_containers, ) # Verify result is None @@ -3263,6 +3410,7 @@ class TestRegisterService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create a real token but with non-existent tenant ID @@ -3283,6 +3431,7 @@ class TestRegisterService: workspace_id=invalid_tenant_id, email=account.email, token=token, + session=db_session_with_containers, ) # Verify result is None (tenant not found) @@ -3308,12 +3457,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") @@ -3333,6 +3483,7 @@ class TestRegisterService: workspace_id=tenant.id, email=account.email, token=token, + session=db_session_with_containers, ) # Verify result is None (account ID mismatch) @@ -3358,12 +3509,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") @@ -3390,6 +3542,7 @@ class TestRegisterService: workspace_id=tenant.id, email=account.email, token=token, + session=db_session_with_containers, ) # Verify result is None (tenant not in normal status) diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 21a768e3446..0ee0cb84e75 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -114,8 +114,9 @@ class TestAgentService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index bc75562d159..6c6b9338d7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -81,8 +81,9 @@ class TestAnnotationService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Setup app creation arguments @@ -280,7 +281,9 @@ class TestAnnotationService: "question": fake.sentence(), "answer": fake.text(max_nb_chars=200), } - updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + updated_annotation = AppAnnotationService.update_app_annotation_directly( + updated_args, app.id, annotation.id, db_session_with_containers + ) # Verify annotation was updated correctly assert updated_annotation.id == annotation.id @@ -567,7 +570,7 @@ class TestAnnotationService: annotation_id = annotation.id # Delete the annotation - AppAnnotationService.delete_app_annotation(app.id, annotation_id) + AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers) # Verify annotation was deleted @@ -595,7 +598,7 @@ class TestAnnotationService: # Try to delete annotation with non-existent app with pytest.raises(NotFound, match="App not found"): - AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) + AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id, db_session_with_containers) def test_delete_app_annotation_annotation_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -609,7 +612,7 @@ class TestAnnotationService: # Try to delete non-existent annotation with pytest.raises(NotFound, match="Annotation not found"): - AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) + AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id, db_session_with_containers) def test_enable_app_annotation_success( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1225,7 +1228,9 @@ class TestAnnotationService: "question": fake.sentence(), "answer": fake.text(max_nb_chars=200), } - updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + updated_annotation = AppAnnotationService.update_app_annotation_directly( + updated_args, app.id, annotation.id, db_session_with_containers + ) # Verify annotation was updated correctly assert updated_annotation.id == annotation.id @@ -1295,7 +1300,7 @@ class TestAnnotationService: mock_external_service_dependencies["delete_task"].delay.reset_mock() # Delete the annotation - AppAnnotationService.delete_app_annotation(app.id, annotation_id) + AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers) # Verify annotation was deleted deleted_annotation = ( diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 8bd4069639f..1f88ce90621 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -57,8 +57,9 @@ class TestAPIBasedExtensionService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant return account, tenant diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index ff74ca3039e..cee08c4c33e 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -145,8 +145,11 @@ class TestAppDslService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + account, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant app_args = CreateAppParams( name=fake.company(), diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index f8482f99c00..fce0d26c484 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -166,8 +166,9 @@ class TestAppGenerateService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppService, CreateAppParams diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 0f5cd184430..f9df99c5594 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -62,8 +62,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Setup app creation arguments @@ -119,8 +120,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -162,8 +164,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -210,8 +213,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -261,8 +265,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppListParams, AppService, CreateAppParams @@ -344,8 +349,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from models import AppStar @@ -404,8 +410,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppService, CreateAppParams, StarredAppListParams @@ -500,8 +507,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -566,14 +574,18 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + first_account, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(first_account, name=fake.company()) tenant = first_account.current_tenant second_account = AccountService.create_account( email=fake.email(), name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) from services.app_service import AppListParams, AppService, CreateAppParams @@ -623,8 +635,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -685,8 +698,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -755,8 +769,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppService, CreateAppParams @@ -807,8 +822,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppService, CreateAppParams @@ -857,8 +873,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -910,8 +927,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -971,8 +989,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1030,8 +1049,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1089,8 +1109,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1139,8 +1160,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1190,8 +1212,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1249,8 +1272,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1287,8 +1311,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1326,8 +1351,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1375,8 +1401,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) # Import here to avoid circular dependency from services.app_service import CreateAppParams @@ -1411,8 +1438,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 6d0d281c6ba..f2d682be3bf 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -98,8 +98,9 @@ class TestMessageService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Setup app creation arguments @@ -648,8 +649,11 @@ class TestMessageService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + other_account, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company()) # Test getting message with different user with pytest.raises(MessageNotExistsError): diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py index 5fa5de6d80d..0969198ecf3 100644 --- a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -4,7 +4,7 @@ from __future__ import annotations import uuid from typing import cast -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch from uuid import uuid4 import pytest @@ -172,4 +172,4 @@ class TestOAuthServerServiceTokenOperations: result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") assert result is expected_user - mock_load.assert_called_once_with("user-88") + mock_load.assert_called_once_with("user-88", ANY) diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py index ff76bce416a..9643fb61d44 100644 --- a/api/tests/test_containers_integration_tests/services/test_ops_service.py +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -51,8 +51,9 @@ class TestOpsService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant app_service = AppService() app = app_service.create_app( diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index ad85ac67bc5..cfd1d4e86b4 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -68,8 +68,9 @@ class TestSavedMessageService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 3a6d635e63b..c464505ef9e 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -55,6 +55,7 @@ class TestTriggerProviderService: def _create_test_account_and_tenant( self, mock_external_service_dependencies: MockExternalServiceDependencies, + db_session_with_containers: Session, ) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. @@ -83,8 +84,9 @@ class TestTriggerProviderService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant assert tenant is not None @@ -164,7 +166,9 @@ class TestTriggerProviderService: - Database state is correctly updated """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -262,7 +266,9 @@ class TestTriggerProviderService: - Merged credentials contain only new values """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -320,7 +326,9 @@ class TestTriggerProviderService: - Original credentials are preserved """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -376,7 +384,9 @@ class TestTriggerProviderService: - UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -434,7 +444,9 @@ class TestTriggerProviderService: - Original subscription state is preserved """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -474,9 +486,8 @@ class TestTriggerProviderService: assert subscription.name == original_name assert subscription.parameters == original_parameters - @pytest.mark.usefixtures("db_session_with_containers") def test_rebuild_trigger_subscription_subscription_not_found( - self, mock_external_service_dependencies: MockExternalServiceDependencies + self, mock_external_service_dependencies: MockExternalServiceDependencies, db_session_with_containers: Session ) -> None: """ Test error when subscription is not found. @@ -485,7 +496,9 @@ class TestTriggerProviderService: - Proper error is raised when subscription doesn't exist """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") fake_subscription_id = fake.uuid4() @@ -509,7 +522,9 @@ class TestTriggerProviderService: - Error is raised when new name conflicts with existing subscription """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 8651636616c..664c1167994 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -72,8 +72,9 @@ class TestWebConversationService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 52b12293027..a7e29045e0d 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -63,8 +63,9 @@ class TestWebhookService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant assert tenant is not None diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index fbbf255c581..cf76afb303c 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -78,8 +78,9 @@ class TestWorkflowAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -126,8 +127,9 @@ class TestWorkflowAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant return tenant, account diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index e065e5df1c3..726c360d77e 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -74,8 +74,9 @@ class TestWorkflowRunService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data @@ -530,8 +531,9 @@ class TestWorkflowRunService: name="Test User", password="password123", interface_language="en-US", + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant") + TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers) tenant = account.current_tenant # Create app @@ -581,8 +583,9 @@ class TestWorkflowRunService: name="Test User", password="password123", interface_language="en-US", + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant") + TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers) tenant = account.current_tenant # Create app @@ -632,8 +635,9 @@ class TestWorkflowRunService: name="Test User", password="password123", interface_language="en-US", + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant") + TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers) tenant = account.current_tenant # Create app diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 9b574fe2dfe..6f342e63dc8 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -89,8 +89,9 @@ class TestWorkflowToolManageService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index ef65b905086..f6e03b84c90 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -90,8 +90,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -211,8 +212,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -255,8 +257,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Test different index types @@ -342,8 +345,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -424,8 +428,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -525,8 +530,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -625,8 +631,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -717,8 +724,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -820,8 +828,11 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + account, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant accounts.append(account) tenants.append(tenant) @@ -926,8 +937,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -1031,8 +1043,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset with built-in fields enabled diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index aba2458d55c..fbdd81aed5d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -67,8 +67,9 @@ class TestDealDatasetVectorIndexTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant assert tenant is not None return account, tenant diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index 6ef7f442591..ebae7de6c15 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -8,7 +8,7 @@ This module tests the account activation mechanism including: - Initial login after activation """ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -41,7 +41,7 @@ class TestActivateCheckApi: } @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation): + def test_check_valid_invitation_token(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock): """ Test checking valid invitation token. @@ -67,7 +67,9 @@ class TestActivateCheckApi: assert response["data"]["email"] == "invitee@example.com" @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_valid_invitation_token_includes_account_status(self, mock_get_invitation, app, mock_invitation): + def test_check_valid_invitation_token_includes_account_status( + self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock + ): mock_account = MagicMock() mock_account.status = AccountStatus.ACTIVE mock_invitation["account"] = mock_account @@ -103,7 +105,9 @@ class TestActivateCheckApi: assert response["is_valid"] is False @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation): + def test_check_token_without_workspace_id( + self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock + ): """ Test checking token without workspace ID. @@ -121,10 +125,10 @@ class TestActivateCheckApi: # Assert assert response["is_valid"] is True - mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token") + mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token", session=ANY) @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation): + def test_check_token_without_email(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation): """ Test checking token without email parameter. @@ -142,10 +146,12 @@ class TestActivateCheckApi: # Assert assert response["is_valid"] is True - mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token") + mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token", session=ANY) @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation): + def test_check_token_normalizes_email_to_lowercase( + self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock + ): """Ensure token validation uses lowercase emails.""" mock_get_invitation.return_value = mock_invitation @@ -156,7 +162,7 @@ class TestActivateCheckApi: response = api.get() assert response["is_valid"] is True - mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY) class TestActivateApi: @@ -554,7 +560,7 @@ class TestActivateApi: response = api.post() assert response["result"] == "success" - mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY) mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") @patch("controllers.console.auth.activate.TenantService.create_tenant_member") @@ -593,7 +599,7 @@ class TestActivateApi: mock_create_tenant_member.assert_called_once_with( mock_invitation["tenant"], mock_account, mock_db.session, role=TenantAccountRole.ADMIN ) - mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id) + mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY) mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") @patch("controllers.console.auth.activate.TenantService.create_tenant_member") @@ -628,5 +634,5 @@ class TestActivateApi: assert response["result"] == "success" mock_create_tenant_member.assert_not_called() - mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id) + mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY) mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py b/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py index df282880af0..14af718590a 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from pydantic import ValidationError @@ -25,6 +25,7 @@ def test_create_new_account_uses_requested_language(mock_create_account): password="ValidPass123!", interface_language="zh-Hans", timezone="Asia/Shanghai", + session=ANY, ) diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index fa23942c653..e56ed48bcfe 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -9,7 +9,7 @@ This module tests the email code login mechanism including: """ import base64 -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -368,6 +368,7 @@ class TestEmailCodeLoginApi: name="newuser@example.com", interface_language="en-US", timezone="Asia/Shanghai", + session=ANY, ) @patch("controllers.console.wraps.db") diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 92656357d4d..7a4f583ba83 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -9,7 +9,7 @@ This module tests the core authentication endpoints including: """ import base64 -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import ANY, MagicMock, Mock, patch import pytest from flask import Flask @@ -129,7 +129,7 @@ class TestLoginApi: response = login_api.post() # Assert - mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None) + mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None, session=ANY) mock_login.assert_called_once() mock_reset_rate_limit.assert_called_once_with("test@example.com") assert response.json["result"] == "success" @@ -184,7 +184,7 @@ class TestLoginApi: response = login_api.post() # Assert - mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token") + mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token", session=ANY) assert response.json["result"] == "success" @patch("controllers.console.wraps.db") @@ -407,13 +407,13 @@ class TestLoginApi: @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit") def test_login_retries_with_lowercase_email( self, - mock_reset_rate_limit, - mock_login_service, - mock_get_tenants, - mock_add_rate_limit, - mock_authenticate, - mock_get_invitation, - mock_is_rate_limit, + mock_reset_rate_limit: MagicMock, + mock_login_service: MagicMock, + mock_get_tenants: MagicMock, + mock_add_rate_limit: MagicMock, + mock_authenticate: MagicMock, + mock_get_invitation: MagicMock, + mock_is_rate_limit: MagicMock, mock_db, app: Flask, mock_account, @@ -435,8 +435,8 @@ class TestLoginApi: assert response.json["result"] == "success" assert mock_authenticate.call_args_list == [ - (("Upper@Example.com", "ValidPass123!", None), {}), - (("upper@example.com", "ValidPass123!", None), {}), + (("Upper@Example.com", "ValidPass123!", None), {"session": ANY}), + (("upper@example.com", "ValidPass123!", None), {"session": ANY}), ] mock_add_rate_limit.assert_not_called() mock_reset_rate_limit.assert_called_once_with("upper@example.com") @@ -447,10 +447,10 @@ class TestLoginApi: @patch("controllers.console.auth.login._get_account_with_case_fallback") def test_email_code_login_logs_banned_account( self, - mock_get_account, - mock_revoke_token, - mock_get_token_data, - mock_db, + mock_get_account: MagicMock, + mock_revoke_token: MagicMock, + mock_get_token_data: MagicMock, + mock_db: MagicMock, app: Flask, ): mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} @@ -491,7 +491,9 @@ class TestLogoutApi: @patch("controllers.console.auth.login.AccountService.logout") @patch("controllers.console.auth.login.flask_login.logout_user") - def test_successful_logout(self, mock_logout_user, mock_service_logout, app: Flask, mock_account): + def test_successful_logout( + self, mock_logout_user: MagicMock, mock_service_logout: MagicMock, app: Flask, mock_account + ): """ Test successful logout flow. diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py index 36c707dbf9b..8b3de6a39e5 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -66,8 +66,9 @@ def test_generate_account_registers_with_browser_timezone( provider="github", language="zh-Hans", timezone="Asia/Shanghai", + session=ANY, ) - mock_link_account.assert_called_once_with("github", "github-123", account) + mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY) @patch("controllers.console.auth.oauth.AccountService.link_account_integrate") @@ -97,8 +98,9 @@ def test_generate_account_prefers_state_language_over_accept_language( provider="github", language="zh-Hans", timezone=None, + session=ANY, ) - mock_link_account.assert_called_once_with("github", "github-123", account) + mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY) @patch("controllers.console.auth.oauth.dify_config") diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index ba69f4d6a78..34fff57b0ad 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -8,7 +8,7 @@ This module tests the token refresh mechanism including: - Error handling for invalid tokens """ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -70,7 +70,7 @@ class TestRefreshTokenApi: # Assert mock_extract_token.assert_called_once() - mock_refresh_token.assert_called_once_with("valid_refresh_token") + mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY) assert response.json["result"] == "success" @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @@ -191,7 +191,7 @@ class TestRefreshTokenApi: # Assert assert response.json["result"] == "success" # Verify new token pair was generated - mock_refresh_token.assert_called_once_with("valid_refresh_token") + mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY) # In real implementation, cookies would be set with new values assert mock_token_pair.access_token == "new_access_token" assert mock_token_pair.refresh_token == "new_refresh_token" diff --git a/api/tests/unit_tests/controllers/console/test_init_validate.py b/api/tests/unit_tests/controllers/console/test_init_validate.py index 4954e0dc96a..88d41fa2bd0 100644 --- a/api/tests/unit_tests/controllers/console/test_init_validate.py +++ b/api/tests/unit_tests/controllers/console/test_init_validate.py @@ -38,7 +38,7 @@ def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None: def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") - monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1) + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 1) app.secret_key = "test-secret" with app.test_request_context("/console/api/init", method="POST"): @@ -48,7 +48,7 @@ def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.Mo def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") - monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0) monkeypatch.setenv("INIT_PASSWORD", "expected") app.secret_key = "test-secret" @@ -60,7 +60,7 @@ def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.M def test_validate_init_password_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") - monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0) monkeypatch.setenv("INIT_PASSWORD", "expected") app.secret_key = "test-secret" diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 1600fcda50d..5f36e805baa 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -1,6 +1,6 @@ import inspect from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -394,12 +394,12 @@ class TestChangeEmailReset: @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_normalize_new_email_before_update( self, - mock_is_freeze, - mock_check_unique, - mock_get_data, - mock_revoke_token, - mock_update_account, - mock_send_notify, + mock_is_freeze: MagicMock, + mock_check_unique: MagicMock, + mock_get_data: MagicMock, + mock_revoke_token: MagicMock, + mock_update_account: MagicMock, + mock_send_notify: MagicMock, app: Flask, ): current_user = _build_account("old@example.com", "acc3") @@ -424,9 +424,9 @@ class TestChangeEmailReset: method(api, current_user) mock_is_freeze.assert_called_once_with("new@example.com") - mock_check_unique.assert_called_once_with("new@example.com") + mock_check_unique.assert_called_once_with("new@example.com", session=ANY) mock_revoke_token.assert_called_once_with("token-123") - mock_update_account.assert_called_once_with(current_user, email="new@example.com") + mock_update_account.assert_called_once_with(current_user, email="new@example.com", session=ANY) mock_send_notify.assert_called_once_with(email="new@example.com") @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") @@ -437,12 +437,12 @@ class TestChangeEmailReset: @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_reject_reset_when_token_phase_is_not_new_verified( self, - mock_is_freeze, - mock_check_unique, - mock_get_data, - mock_revoke_token, - mock_update_account, - mock_send_notify, + mock_is_freeze: MagicMock, + mock_check_unique: MagicMock, + mock_get_data: MagicMock, + mock_revoke_token: MagicMock, + mock_update_account: MagicMock, + mock_send_notify: MagicMock, app: Flask, ): """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" @@ -480,12 +480,12 @@ class TestChangeEmailReset: @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_reject_reset_when_token_email_differs_from_payload_new_email( self, - mock_is_freeze, - mock_check_unique, - mock_get_data, - mock_revoke_token, - mock_update_account, - mock_send_notify, + mock_is_freeze: MagicMock, + mock_check_unique: MagicMock, + mock_get_data: MagicMock, + mock_revoke_token: MagicMock, + mock_update_account: MagicMock, + mock_send_notify: MagicMock, app: Flask, ): """A verified token for address A must not be replayed to change to address B.""" @@ -523,12 +523,12 @@ class TestChangeEmailReset: @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_reject_reset_when_token_account_id_does_not_match_current_user( self, - mock_is_freeze, - mock_check_unique, - mock_get_data, - mock_revoke_token, - mock_update_account, - mock_send_notify, + mock_is_freeze: MagicMock, + mock_check_unique: MagicMock, + mock_get_data: MagicMock, + mock_revoke_token: MagicMock, + mock_update_account: MagicMock, + mock_send_notify: MagicMock, app: Flask, ): from controllers.console.auth.error import InvalidTokenError @@ -575,9 +575,9 @@ class TestAccountServiceSendChangeEmailEmail: @patch("services.account_service.AccountService.generate_change_email_token") def test_should_bind_account_id_and_target_email_into_generated_token( self, - mock_generate_token, - mock_rate_limiter, - mock_mail_task, + mock_generate_token: MagicMock, + mock_rate_limiter: MagicMock, + mock_mail_task: MagicMock, ): mock_rate_limiter.is_rate_limited.return_value = False mock_generate_token.return_value = "the-token" @@ -665,7 +665,7 @@ class TestAccountDeletionFeedback: class TestCheckEmailUnique: @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, app: Flask): + def test_should_normalize_email(self, mock_is_freeze: MagicMock, mock_check_unique: MagicMock, app: Flask): mock_is_freeze.return_value = False mock_check_unique.return_value = True @@ -680,7 +680,7 @@ class TestCheckEmailUnique: assert response == {"result": "success"} mock_is_freeze.assert_called_once_with("case@test.com") - mock_check_unique.assert_called_once_with("case@test.com") + mock_check_unique.assert_called_once_with("case@test.com", session=ANY) def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 25ae0778d4b..a6626adc420 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -8,7 +8,7 @@ handler tests use inspect.unwrap() to bypass them and focus on business logic. import inspect from datetime import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -115,7 +115,7 @@ class TestEnterpriseWorkspace: assert result["message"] == "enterprise workspace created." assert result["tenant"]["id"] == "tenant-id" assert result["tenant"]["name"] == "My Workspace" - mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY) mock_tenant_svc.create_tenant_member.assert_called_once_with( mock_tenant, mock_account, mock_db.session, role="owner" ) @@ -183,5 +183,5 @@ class TestEnterpriseWorkspaceNoOwnerEmail: assert result["tenant"]["id"] == "tenant-id" assert result["tenant"]["encrypt_public_key"] == "pub-key" assert result["tenant"]["custom_config"] == {} - mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY) mock_event.send.assert_called_once_with(mock_tenant) diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index c748fc0962e..214344e0e57 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -179,7 +179,7 @@ class TestAccountService: mock_password_dependencies["compare_password"].return_value = True # Execute test - result = AccountService.authenticate("test@example.com", "password") + result = AccountService.authenticate("test@example.com", "password", session=mock_db_dependencies["db"].session) # Verify results assert result == mock_account @@ -191,7 +191,11 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password" + AccountPasswordError, + AccountService.authenticate, + "notfound@example.com", + "password", + session=mock_db_dependencies["db"].session, ) def test_authenticate_account_banned(self, mock_db_dependencies): @@ -202,7 +206,13 @@ class TestAccountService: mock_db_dependencies["db"].session.scalar.return_value = mock_account # Execute test and verify exception - self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password") + self._assert_exception_raised( + AccountLoginError, + AccountService.authenticate, + "banned@example.com", + "password", + session=mock_db_dependencies["db"].session, + ) def test_authenticate_password_error(self, mock_db_dependencies, mock_password_dependencies): """Test authentication with wrong password.""" @@ -215,7 +225,11 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - AccountPasswordError, AccountService.authenticate, "test@example.com", "wrongpassword" + AccountPasswordError, + AccountService.authenticate, + "test@example.com", + "wrongpassword", + session=mock_db_dependencies["db"].session, ) def test_authenticate_pending_account_activates(self, mock_db_dependencies, mock_password_dependencies): @@ -228,7 +242,9 @@ class TestAccountService: mock_password_dependencies["compare_password"].return_value = True # Execute test - result = AccountService.authenticate("pending@example.com", "password") + result = AccountService.authenticate( + "pending@example.com", "password", session=mock_db_dependencies["db"].session + ) # Verify results assert result == mock_account @@ -253,6 +269,7 @@ class TestAccountService: interface_language="en-US", password="password123", interface_theme="light", + session=mock_db_dependencies["db"].session, ) # Verify results @@ -290,6 +307,7 @@ class TestAccountService: interface_language="en-US", password="password123", timezone="Asia/Shanghai", + session=mock_db_dependencies["db"].session, ) assert result.timezone == "Asia/Shanghai" @@ -309,6 +327,7 @@ class TestAccountService: email="test@example.com", name="Test User", interface_language="en-US", + session=MagicMock(), ) def test_create_account_email_frozen(self, mock_db_dependencies, mock_external_service_dependencies): @@ -325,6 +344,7 @@ class TestAccountService: email="frozen@example.com", name="Test User", interface_language="en-US", + session=mock_db_dependencies["db"].session, ) dify_config.BILLING_ENABLED = False @@ -341,6 +361,7 @@ class TestAccountService: interface_language="zh-CN", password=None, interface_theme="dark", + session=mock_db_dependencies["db"].session, ) # Verify results @@ -375,7 +396,9 @@ class TestAccountService: mock_password_dependencies["hash_password"].return_value = b"new_hashed_password" # Execute test - result = AccountService.update_account_password(mock_account, "old_password", "new_password123") + result = AccountService.update_account_password( + mock_account, "old_password", "new_password123", session=mock_db_dependencies["db"].session + ) # Verify results assert result == mock_account @@ -391,7 +414,7 @@ class TestAccountService: # Verify database operations self._assert_database_operations_called(mock_db_dependencies["db"]) - def test_update_account_password_current_password_incorrect(self, mock_password_dependencies): + def test_update_account_password_current_password_incorrect(self, mock_db_dependencies, mock_password_dependencies): """Test password update with incorrect current password.""" # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock() @@ -404,6 +427,7 @@ class TestAccountService: mock_account, "wrong_password", "new_password123", + session=mock_db_dependencies["db"].session, ) # Verify password comparison was called @@ -411,7 +435,7 @@ class TestAccountService: "wrong_password", "hashed_password", "salt" ) - def test_update_account_password_invalid_new_password(self, mock_password_dependencies): + def test_update_account_password_invalid_new_password(self, mock_db_dependencies, mock_password_dependencies): """Test password update with invalid new password.""" # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock() @@ -420,7 +444,12 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - ValueError, AccountService.update_account_password, mock_account, "old_password", "short" + ValueError, + AccountService.update_account_password, + mock_account, + "old_password", + "short", + session=mock_db_dependencies["db"].session, ) # Verify password validation was called @@ -447,19 +476,19 @@ class TestAccountService: mock_datetime.UTC = "UTC" # Execute test - result = AccountService.load_user("user-123") + result = AccountService.load_user("user-123", mock_db_dependencies["db"].session) # Verify results assert result == mock_account assert mock_account.set_tenant_id.called - mock_refresh_last_active.assert_called_once_with(mock_account) + mock_refresh_last_active.assert_called_once_with(mock_account, mock_db_dependencies["db"].session) def test_load_user_not_found(self, mock_db_dependencies): """Test user loading when user does not exist.""" mock_db_dependencies["db"].session.get.return_value = None # Execute test - result = AccountService.load_user("non-existent-user") + result = AccountService.load_user("non-existent-user", mock_db_dependencies["db"].session) # Verify results assert result is None @@ -500,14 +529,14 @@ class TestAccountService: mock_naive_utc_now.return_value = mock_now # Execute test - result = AccountService.load_user("user-123") + result = AccountService.load_user("user-123", mock_db_dependencies["db"].session) # Verify results assert result == mock_account assert mock_available_tenant.current is True assert mock_available_tenant.last_opened_at == mock_now self._assert_database_operations_called(mock_db_dependencies["db"]) - mock_refresh_last_active.assert_called_once_with(mock_account) + mock_refresh_last_active.assert_called_once_with(mock_account, mock_db_dependencies["db"].session) def test_load_user_no_tenants(self, mock_db_dependencies): """Test user loading when user has no tenants at all.""" @@ -525,7 +554,7 @@ class TestAccountService: mock_datetime.UTC = "UTC" # Execute test - result = AccountService.load_user("user-123") + result = AccountService.load_user("user-123", mock_db_dependencies["db"].session) # Verify results assert result is None @@ -542,7 +571,7 @@ class TestAccountService: ): mock_redis_client.set.return_value = True - AccountService._refresh_account_last_active(mock_account) + AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session) mock_redis_client.set.assert_called_once_with( "account_last_active_refresh:user-123", @@ -565,7 +594,7 @@ class TestAccountService: ): mock_redis_client.set.return_value = None - AccountService._refresh_account_last_active(mock_account) + AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session) mock_redis_client.set.assert_called_once_with( "account_last_active_refresh:user-123", @@ -586,7 +615,7 @@ class TestAccountService: patch("services.account_service.naive_utc_now", return_value=now), patch("services.account_service.redis_client") as mock_redis_client, ): - AccountService._refresh_account_last_active(mock_account) + AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session) mock_redis_client.set.assert_not_called() mock_db_dependencies["db"].session.execute.assert_not_called() @@ -736,7 +765,9 @@ class TestTenantService: mock_credit_pool_db.session.commit = MagicMock() # Execute test - TenantService.create_owner_tenant_if_not_exist(mock_account) + TenantService.create_owner_tenant_if_not_exist( + mock_account, session=mock_db_dependencies["db"].session + ) # Verify tenant was created with correct parameters mock_db_dependencies["db"].session.add.assert_called() @@ -838,7 +869,9 @@ class TestTenantService: mock_sync.return_value = True # Act - TenantService.remove_member_from_tenant(mock_tenant, mock_pending_member, mock_operator) + TenantService.remove_member_from_tenant( + mock_tenant, mock_pending_member, mock_operator, session=mock_db.session + ) # Assert: enterprise sync still receives the correct member ID mock_sync.assert_called_once_with( @@ -878,7 +911,9 @@ class TestTenantService: mock_sync.return_value = True # Act - TenantService.remove_member_from_tenant(mock_tenant, mock_pending_member, mock_operator) + TenantService.remove_member_from_tenant( + mock_tenant, mock_pending_member, mock_operator, session=mock_db.session + ) # Assert: only the join record should be deleted, not the account mock_db.session.delete.assert_called_once_with(mock_ta) @@ -909,7 +944,9 @@ class TestTenantService: mock_sync.return_value = True # Act - TenantService.remove_member_from_tenant(mock_tenant, mock_active_member, mock_operator) + TenantService.remove_member_from_tenant( + mock_tenant, mock_active_member, mock_operator, session=mock_db.session + ) # Assert: only the join record should be deleted mock_db.session.delete.assert_called_once_with(mock_ta) @@ -934,7 +971,7 @@ class TestTenantService: mock_naive_utc_now.return_value = mock_now # Execute test - TenantService.switch_tenant(mock_account, "tenant-456") + TenantService.switch_tenant(mock_account, "tenant-456", session=mock_db.session) # Verify tenant was switched assert mock_tenant_join.current is True @@ -947,7 +984,7 @@ class TestTenantService: mock_account = TestAccountAssociatedDataFactory.create_account_mock() # Execute test and verify exception - self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None) + self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None, session=MagicMock()) # ==================== Role Management Tests ==================== @@ -971,7 +1008,7 @@ class TestTenantService: mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join] # Execute test - TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator) + TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator, session=mock_db.session) # Verify role was updated assert mock_target_join.role == "admin" @@ -1005,7 +1042,9 @@ class TestTenantService: ): mock_db_dependencies["db"].session.scalar.return_value = None - TenantService.create_owner_tenant_if_not_exist(mock_account, is_setup=True) + TenantService.create_owner_tenant_if_not_exist( + mock_account, is_setup=True, session=mock_db_dependencies["db"].session + ) mock_rbac_service.MemberRoles.replace.assert_called_once_with( tenant_id="tenant-rbac", @@ -1030,7 +1069,7 @@ class TestTenantService: with patch("services.account_service.db") as mock_db: mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join] - TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator) + TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator, session=mock_db.session) assert mock_target_join.role == "editor" self._assert_database_operations_called(mock_db) @@ -1052,7 +1091,9 @@ class TestTenantService: mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join] with pytest.raises(NoPermissionError): - TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator) + TenantService.update_member_role( + mock_tenant, mock_member, "editor", mock_operator, session=mock_db.session + ) def test_admin_cannot_promote_member_to_owner(self): """Test admin cannot promote a non-owner member to owner.""" @@ -1071,7 +1112,9 @@ class TestTenantService: mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join] with pytest.raises(NoPermissionError): - TenantService.update_member_role(mock_tenant, mock_member, "owner", mock_operator) + TenantService.update_member_role( + mock_tenant, mock_member, "owner", mock_operator, session=mock_db.session + ) # ==================== Permission Check Tests ==================== @@ -1089,7 +1132,9 @@ class TestTenantService: mock_db_dependencies["db"].session.scalar.return_value = mock_operator_join # Execute test - should not raise exception - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "add", session=mock_db_dependencies["db"].session + ) def test_check_member_permission_operate_self(self): """Test member permission check when operator tries to operate self.""" @@ -1108,6 +1153,7 @@ class TestTenantService: mock_operator, mock_operator, # Same as operator "add", + session=MagicMock(), ) def test_admin_can_remove_non_owner_member(self, mock_db_dependencies): @@ -1124,7 +1170,9 @@ class TestTenantService: ) mock_db_dependencies["db"].session.scalar.side_effect = [mock_operator_join, mock_member_join] - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=mock_db_dependencies["db"].session + ) def test_admin_cannot_remove_owner_member(self, mock_db_dependencies): """Test admin cannot remove an owner member.""" @@ -1141,7 +1189,9 @@ class TestTenantService: mock_db_dependencies["db"].session.scalar.side_effect = [mock_operator_join, mock_member_join] with pytest.raises(NoPermissionError): - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=MagicMock() + ) def test_rbac_member_can_remove_non_owner_member(self): """Test RBAC workspace.member.manage allows removing a non-owner member.""" @@ -1158,7 +1208,9 @@ class TestTenantService: patch("services.account_service.RBACService.MyPermissions.get", return_value=mock_permissions), patch("services.account_service.AccountService.is_rbac_workspace_owner", return_value=False), ): - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=MagicMock() + ) def test_rbac_member_cannot_remove_without_permission(self): """Test RBAC permission check rejects removal without workspace.member.manage.""" @@ -1175,7 +1227,9 @@ class TestTenantService: patch("services.account_service.RBACService.MyPermissions.get", return_value=mock_permissions), ): with pytest.raises(NoPermissionError): - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=MagicMock() + ) def test_rbac_member_cannot_remove_owner_member(self): """Test RBAC permission check rejects removing an owner member.""" @@ -1193,7 +1247,9 @@ class TestTenantService: patch("services.account_service.AccountService.is_rbac_workspace_owner", return_value=True), ): with pytest.raises(NoPermissionError): - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=MagicMock() + ) def test_get_rbac_workspace_owner_account_id(self): mock_roles = MagicMock() @@ -1304,7 +1360,14 @@ class TestRegisterService: mock_dify_setup.return_value = mock_dify_setup_instance # Execute test - RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1", "en-US") + RegisterService.setup( + "admin@example.com", + "Admin User", + "password123", + "192.168.1.1", + "en-US", + session=mock_db_dependencies["db"].session, + ) # Verify results mock_create_account.assert_called_once_with( @@ -1313,8 +1376,11 @@ class TestRegisterService: interface_language="en-US", password="password123", is_setup=True, + session=mock_db_dependencies["db"].session, + ) + mock_create_tenant.assert_called_once_with( + account=mock_account, is_setup=True, session=mock_db_dependencies["db"].session ) - mock_create_tenant.assert_called_once_with(account=mock_account, is_setup=True) mock_dify_setup.assert_called_once() self._assert_database_operations_called(mock_db_dependencies["db"]) @@ -1337,6 +1403,7 @@ class TestRegisterService: "password123", "192.168.1.1", "en-US", + session=mock_db_dependencies["db"].session, ) # Verify rollback operations were called @@ -1369,10 +1436,13 @@ class TestRegisterService: name="Test User", interface_language="en-US", password=None, + session=mock_db_dependencies["db"].session, ) assert result == mock_account - mock_create_workspace.assert_called_once_with(account=mock_account) + mock_create_workspace.assert_called_once_with( + account=mock_account, session=mock_db_dependencies["db"].session + ) mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled( @@ -1400,9 +1470,12 @@ class TestRegisterService: name="Test User", interface_language="en-US", password=None, + session=mock_db_dependencies["db"].session, ) - mock_create_workspace.assert_called_once_with(account=mock_account) + mock_create_workspace.assert_called_once_with( + account=mock_account, session=mock_db_dependencies["db"].session + ) mock_join_default_workspace.assert_not_called() def test_create_account_and_tenant_still_calls_default_workspace_join_when_workspace_creation_fails( @@ -1433,6 +1506,7 @@ class TestRegisterService: name="Test User", interface_language="en-US", password=None, + session=mock_db_dependencies["db"].session, ) mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) @@ -1470,6 +1544,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) # Verify results @@ -1483,8 +1558,11 @@ class TestRegisterService: password="password123", is_setup=False, timezone=None, + session=mock_db_dependencies["db"].session, + ) + mock_create_tenant.assert_called_once_with( + "Test User's Workspace", session=mock_db_dependencies["db"].session ) - mock_create_tenant.assert_called_once_with("Test User's Workspace") mock_create_member.assert_called_once_with( mock_tenant, mock_account, mock_db_dependencies["db"].session, role="owner" ) @@ -1516,6 +1594,7 @@ class TestRegisterService: password="password123", language="en-US", create_workspace_required=False, + session=mock_db_dependencies["db"].session, ) assert result == mock_account @@ -1546,6 +1625,7 @@ class TestRegisterService: password="password123", language="en-US", create_workspace_required=False, + session=mock_db_dependencies["db"].session, ) mock_join_default_workspace.assert_not_called() @@ -1584,6 +1664,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) @@ -1623,6 +1704,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) @@ -1665,11 +1747,14 @@ class TestRegisterService: open_id="oauth123", provider="google", language="en-US", + session=mock_db_dependencies["db"].session, ) # Verify results assert result == mock_account - mock_link_account.assert_called_once_with("google", "oauth123", mock_account) + mock_link_account.assert_called_once_with( + "google", "oauth123", mock_account, session=mock_db_dependencies["db"].session + ) self._assert_database_operations_called(mock_db_dependencies["db"]) def test_register_with_pending_status(self, mock_db_dependencies, mock_external_service_dependencies): @@ -1707,6 +1792,7 @@ class TestRegisterService: password="password123", language="en-US", status=AccountStatus.PENDING, + session=mock_db_dependencies["db"].session, ) # Verify results @@ -1744,6 +1830,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) # Verify rollback was called @@ -1767,6 +1854,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) # Verify rollback was called @@ -1810,6 +1898,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) # Verify results @@ -1820,6 +1909,7 @@ class TestRegisterService: language="en-US", status=AccountStatus.PENDING, is_setup=True, + session=mock_db_dependencies["db"].session, ) mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "newuser@example.com") @@ -1856,6 +1946,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) mock_register.assert_called_once_with( @@ -1864,13 +1955,22 @@ class TestRegisterService: language="en-US", status=AccountStatus.PENDING, is_setup=True, + session=mock_db_dependencies["db"].session, ) mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, mixed_email) - mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") + mock_check_permission.assert_called_once_with( + mock_tenant, + mock_inviter, + None, + "add", + session=mock_db_dependencies["db"].session, + ) mock_create_member.assert_called_once_with( mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "normal" ) - mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) + mock_switch_tenant.assert_called_once_with( + mock_new_account, mock_tenant.id, session=mock_db_dependencies["db"].session + ) mock_generate_token.assert_called_once_with( mock_tenant, mock_new_account, "normal", requires_setup=True ) @@ -1912,6 +2012,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) # Verify results @@ -1954,10 +2055,17 @@ class TestRegisterService: language="en-US", role="admin", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) assert result == "invite-token-123" - mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, mock_existing_account, "add") + mock_check_permission.assert_called_once_with( + mock_tenant, + mock_inviter, + mock_existing_account, + "add", + session=mock_db_dependencies["db"].session, + ) mock_create_member.assert_not_called() mock_generate_token.assert_called_once_with( mock_tenant, mock_existing_account, "admin", requires_setup=False @@ -1993,6 +2101,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) mock_lookup.assert_called_once() @@ -2010,6 +2119,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=None, + session=MagicMock(), ) # ==================== RBAC Member Invitation Tests ==================== @@ -2048,6 +2158,7 @@ class TestRegisterService: language="en-US", role="rbac-role-id-123", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) assert result == "rbac-token" @@ -2093,6 +2204,7 @@ class TestRegisterService: language="en-US", role="rbac-role-id-456", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) assert result == "rbac-token" @@ -2141,6 +2253,7 @@ class TestRegisterService: language="en-US", role="rbac-role-id-456", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) mock_create_member.assert_called_once_with( @@ -2191,6 +2304,7 @@ class TestRegisterService: language="en-US", role="editor", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) assert result == "legacy-token" @@ -2300,7 +2414,9 @@ class TestRegisterService: mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, mock_account] # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session + ) # Verify results assert result is not None @@ -2314,7 +2430,9 @@ class TestRegisterService: mock_redis_dependencies.get.return_value = None # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=MagicMock() + ) # Verify results assert result is None @@ -2333,7 +2451,9 @@ class TestRegisterService: mock_db_dependencies["db"].session.scalar.return_value = None # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session + ) # Verify results assert result is None @@ -2357,7 +2477,9 @@ class TestRegisterService: mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, None] # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session + ) # Verify results assert result is None @@ -2384,7 +2506,9 @@ class TestRegisterService: mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, mock_account] # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session + ) # Verify results assert result is None @@ -2395,22 +2519,28 @@ class TestRegisterService: with patch( "services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation ) as mock_get: - result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + result = RegisterService.get_invitation_with_case_fallback( + "tenant-456", "User@Test.com", "token-123", session=MagicMock() + ) assert result == invitation - mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123") + mock_get.assert_called_once_with( + "tenant-456", "User@Test.com", "token-123", session=mock_get.call_args.kwargs["session"] + ) def test_get_invitation_with_case_fallback_retries_with_lowercase(self): """Fallback helper should retry with lowercase email when needed.""" invitation = {"workspace_id": "tenant-456"} with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get: mock_get.side_effect = [None, invitation] - result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + result = RegisterService.get_invitation_with_case_fallback( + "tenant-456", "User@Test.com", "token-123", session=MagicMock() + ) assert result == invitation assert mock_get.call_args_list == [ - (("tenant-456", "User@Test.com", "token-123"),), - (("tenant-456", "user@test.com", "token-123"),), + (("tenant-456", "User@Test.com", "token-123"), {"session": mock_get.call_args_list[0].kwargs["session"]}), + (("tenant-456", "user@test.com", "token-123"), {"session": mock_get.call_args_list[1].kwargs["session"]}), ] # ==================== Helper Method Tests ==================== diff --git a/api/tests/unit_tests/services/test_annotation_service.py b/api/tests/unit_tests/services/test_annotation_service.py index 55912cc1c1f..7574c13342a 100644 --- a/api/tests/unit_tests/services/test_annotation_service.py +++ b/api/tests/unit_tests/services/test_annotation_service.py @@ -546,7 +546,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(NotFound): - AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1") + AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1", mock_db.session) def test_update_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None: """Test missing app raises NotFound in update path.""" @@ -562,7 +562,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(NotFound): - AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1") + AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1", mock_db.session) def test_update_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None: """Test missing question raises ValueError.""" @@ -581,7 +581,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(ValueError): - AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session) def test_update_app_annotation_directly_should_update_annotation_and_index(self) -> None: """Test update changes fields and triggers index update.""" @@ -602,7 +602,7 @@ class TestAppAnnotationServiceDirectManipulation: mock_db.session.get.return_value = annotation # Act - result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session) # Assert assert result == annotation @@ -640,7 +640,7 @@ class TestAppAnnotationServiceDirectManipulation: mock_db.session.scalars.return_value = scalars_result # Act - AppAnnotationService.delete_app_annotation(app.id, annotation.id) + AppAnnotationService.delete_app_annotation(app.id, annotation.id, mock_db.session) # Assert mock_db.session.delete.assert_any_call(annotation) @@ -667,7 +667,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(NotFound): - AppAnnotationService.delete_app_annotation("app-1", "ann-1") + AppAnnotationService.delete_app_annotation("app-1", "ann-1", mock_db.session) def test_delete_app_annotation_should_raise_not_found_when_annotation_missing(self) -> None: """Test delete raises NotFound when annotation is missing.""" @@ -684,7 +684,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(NotFound): - AppAnnotationService.delete_app_annotation(app.id, "ann-1") + AppAnnotationService.delete_app_annotation(app.id, "ann-1", mock_db.session) def test_delete_app_annotations_in_batch_should_return_zero_when_none_found(self) -> None: """Test batch delete returns zero when no annotations found."""