mirror of https://github.com/langgenius/dify.git
test: add comprehensive unit tests for login decorator (#22294)
This commit is contained in:
parent
1b26f9a4c6
commit
27e5e2745b
|
|
@ -0,0 +1,232 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask_login import LoginManager, UserMixin
|
||||
|
||||
from libs.login import _get_user, current_user, login_required
|
||||
|
||||
|
||||
class MockUser(UserMixin):
|
||||
"""Mock user class for testing."""
|
||||
|
||||
def __init__(self, id: str, is_authenticated: bool = True):
|
||||
self.id = id
|
||||
self._is_authenticated = is_authenticated
|
||||
|
||||
@property
|
||||
def is_authenticated(self):
|
||||
return self._is_authenticated
|
||||
|
||||
|
||||
class TestLoginRequired:
|
||||
"""Test cases for login_required decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_app(self, app: Flask):
|
||||
"""Set up Flask app with login manager."""
|
||||
# Initialize login manager
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
# Mock unauthorized handler
|
||||
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
|
||||
|
||||
# Add a dummy user loader to prevent exceptions
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id):
|
||||
return None
|
||||
|
||||
return app
|
||||
|
||||
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that authenticated users can access protected views."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock authenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that unauthenticated users are redirected."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Unauthorized"
|
||||
setup_app.login_manager.unauthorized.assert_called_once()
|
||||
|
||||
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
|
||||
"""Test that LOGIN_DISABLED config bypasses authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user and LOGIN_DISABLED
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login.dify_config") as mock_config:
|
||||
mock_config.LOGIN_DISABLED = True
|
||||
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
def test_options_request_bypasses_authentication(self, setup_app: Flask):
|
||||
"""Test that OPTIONS requests are exempt from authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context(method="OPTIONS"):
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
def test_flask_2_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 2.x compatibility with ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Mock Flask 2.x ensure_sync
|
||||
setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Synced content"
|
||||
setup_app.ensure_sync.assert_called_once()
|
||||
|
||||
def test_flask_1_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 1.x compatibility without ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Remove ensure_sync to simulate Flask 1.x
|
||||
if hasattr(setup_app, "ensure_sync"):
|
||||
delattr(setup_app, "ensure_sync")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test cases for _get_user function."""
|
||||
|
||||
def test_get_user_returns_user_from_g(self, app: Flask):
|
||||
"""Test that _get_user returns user from g._login_user."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_user
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
assert user.id == "test_user"
|
||||
|
||||
def test_get_user_loads_user_if_not_in_g(self, app: Flask):
|
||||
"""Test that _get_user loads user if not already in g."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
# Mock login manager
|
||||
login_manager = MagicMock()
|
||||
login_manager._load_user = MagicMock()
|
||||
app.login_manager = login_manager
|
||||
|
||||
with app.test_request_context():
|
||||
# Simulate _load_user setting g._login_user
|
||||
def side_effect():
|
||||
g._login_user = mock_user
|
||||
|
||||
login_manager._load_user.side_effect = side_effect
|
||||
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
login_manager._load_user.assert_called_once()
|
||||
|
||||
def test_get_user_returns_none_without_request_context(self, app: Flask):
|
||||
"""Test that _get_user returns None outside request context."""
|
||||
# Outside of request context
|
||||
user = _get_user()
|
||||
assert user is None
|
||||
|
||||
|
||||
class TestCurrentUser:
|
||||
"""Test cases for current_user proxy."""
|
||||
|
||||
def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
|
||||
"""Test that current_user proxy returns authenticated user."""
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
assert current_user.id == "test_user"
|
||||
assert current_user.is_authenticated is True
|
||||
|
||||
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
|
||||
"""Test that current_user proxy handles None user."""
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=None):
|
||||
# When _get_user returns None, accessing attributes should fail
|
||||
# or current_user should evaluate to falsy
|
||||
try:
|
||||
# Try to access an attribute that would exist on a real user
|
||||
_ = current_user.id
|
||||
pytest.fail("Should have raised AttributeError")
|
||||
except AttributeError:
|
||||
# This is expected when current_user is None
|
||||
pass
|
||||
|
||||
def test_current_user_proxy_thread_safety(self, app: Flask):
|
||||
"""Test that current_user proxy is thread-safe."""
|
||||
import threading
|
||||
|
||||
results = {}
|
||||
|
||||
def check_user_in_thread(user_id: str, index: int):
|
||||
with app.test_request_context():
|
||||
mock_user = MockUser(user_id)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
results[index] = current_user.id
|
||||
|
||||
# Create multiple threads with different users
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify each thread got its own user
|
||||
for i in range(5):
|
||||
assert results[i] == f"user_{i}"
|
||||
Loading…
Reference in New Issue