mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 21:11:16 +08:00
fix: resolve DetachedInstanceError via session management refactoring (#37847)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
4086f5f2d9
commit
2cde7e4a94
@ -25,7 +25,7 @@ def reset_password(email, new_password, password_confirm):
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip())
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
@ -67,7 +67,7 @@ def reset_email(email, new_email, email_confirm):
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip())
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console.auth.error import (
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
@ -100,7 +101,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@ -175,7 +176,7 @@ class EmailRegisterResetApi(Resource):
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
@ -82,7 +82,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
@ -180,7 +180,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
|
||||
if account:
|
||||
account = db.session.merge(account)
|
||||
|
||||
@ -224,7 +224,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, user_info.email)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@ -131,7 +131,7 @@ def _normalize_invitee_emails(emails: list[str]) -> list[str]:
|
||||
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
|
||||
new_member_count = 0
|
||||
for email in emails:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
if not account:
|
||||
new_member_count += 1
|
||||
continue
|
||||
|
||||
@ -193,7 +193,7 @@ class WorkspaceMembersApi(Resource):
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
normalized_email = body.email.lower()
|
||||
member = AccountService.get_account_by_email_with_case_fallback(normalized_email)
|
||||
member = AccountService.get_account_by_email_with_case_fallback(db.session, normalized_email)
|
||||
if member is None:
|
||||
# invite_new_member just created or fetched this account.
|
||||
raise RuntimeError("invited member missing from DB after invite")
|
||||
|
||||
@ -69,7 +69,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, request_email)
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
@ -168,7 +168,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
|
||||
if account:
|
||||
account = db.session.merge(account)
|
||||
|
||||
@ -14,7 +14,6 @@ from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import get_valid_language, language_timezone_mapping
|
||||
from core.db.session_factory import session_factory
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
@ -981,19 +980,18 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
|
||||
def get_account_by_email_with_case_fallback(session: Session | scoped_session, email: str) -> Account | None:
|
||||
"""
|
||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||
|
||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
|
||||
return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
@ -1958,7 +1956,7 @@ class RegisterService:
|
||||
|
||||
check_workspace_member_invite_permission(tenant.id)
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
|
||||
requires_setup = False
|
||||
if not account:
|
||||
|
||||
@ -35,7 +35,7 @@ class WebAppAuthService:
|
||||
@staticmethod
|
||||
def authenticate(email: str, password: str) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
|
||||
@ -55,7 +55,7 @@ class WebAppAuthService:
|
||||
|
||||
@classmethod
|
||||
def get_user_through_email(cls, email: str):
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
|
||||
@ -270,10 +270,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||
result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -165,10 +165,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||
result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -494,10 +494,7 @@ class TestAccountGeneration:
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||
result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -57,7 +57,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com")
|
||||
mock_get_account.assert_called_once_with(ANY, "User@Example.com")
|
||||
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
||||
@ -177,7 +177,7 @@ class TestForgotPasswordResetApi:
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com")
|
||||
mock_get_account.assert_called_once_with(ANY, "User@Example.com")
|
||||
mock_update_account.assert_called_once()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
@ -692,12 +692,7 @@ def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():
|
||||
second.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first, second]
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("services.account_service.session_factory", mock_factory):
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||
result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -1821,7 +1821,7 @@ class TestRegisterService:
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with("newuser@example.com")
|
||||
mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "newuser@example.com")
|
||||
|
||||
def test_invite_new_member_normalizes_new_account_email(
|
||||
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
||||
@ -1865,7 +1865,7 @@ class TestRegisterService:
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
)
|
||||
mock_lookup.assert_called_once_with(mixed_email)
|
||||
mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, mixed_email)
|
||||
mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add")
|
||||
mock_create_member.assert_called_once_with(
|
||||
mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "normal"
|
||||
@ -1923,7 +1923,7 @@ class TestRegisterService:
|
||||
mock_tenant, mock_existing_account, "normal", requires_setup=True
|
||||
)
|
||||
mock_task_dependencies.delay.assert_called_once()
|
||||
mock_lookup.assert_called_once_with("existing@example.com")
|
||||
mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "existing@example.com")
|
||||
|
||||
def test_invite_existing_active_account_requires_acceptance_before_joining(
|
||||
self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
|
||||
|
||||
Loading…
Reference in New Issue
Block a user