dify/api/tests/unit_tests/libs/test_login.py

270 lines
9.2 KiB
Python

from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask, Response, g
from flask_login import UserMixin
from pytest_mock import MockerFixture
import libs.login as login_module
from extensions.ext_login import DifyLoginManager
from libs.login import current_user
from models.account import Account
@pytest.fixture
def protected_view():
    """Provide a tiny view wrapped by ``login_module.login_required``.

    The wrapped view simply returns ``"Protected content"`` so tests can tell
    whether the decorator let the call through.
    """

    def _protected_view():
        return "Protected content"

    # Applying the decorator by an explicit call is equivalent to the ``@`` form.
    return login_module.login_required(_protected_view)
class MockUser(UserMixin):
    """Lightweight user double whose authentication flag is set at construction."""

    def __init__(self, id: str, is_authenticated: bool = True) -> None:
        # Stored privately and exposed through the read-only property below.
        self._is_authenticated = is_authenticated
        self.id = id

    @property
    def is_authenticated(self) -> bool:
        """Return the authentication flag supplied to the constructor."""
        return self._is_authenticated
@pytest.fixture
def login_app(mocker: MockerFixture) -> Flask:
    """Create a Flask app in testing mode with a ``DifyLoginManager`` attached.

    The manager's ``unauthorized`` handler is replaced by a mock returning a
    401 JSON ``Response``, and the registered user loader always returns ``None``.
    """
    app = Flask(__name__)
    app.config.update(TESTING=True)

    manager = DifyLoginManager()
    manager.init_app(app)

    unauthorized_response = Response("Unauthorized", status=401, content_type="application/json")
    manager.unauthorized = mocker.Mock(name="unauthorized", return_value=unauthorized_response)

    def load_user(_user_id: str):
        return None

    # Register the loader explicitly; same effect as decorating ``load_user``.
    manager.user_loader(load_user)

    return app
@pytest.fixture(autouse=True)
def reset_login_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """Set ``dify_config.LOGIN_DISABLED`` to ``False`` for every test in this module."""
    config = login_module.dify_config
    monkeypatch.setattr(config, "LOGIN_DISABLED", False)
@pytest.fixture
def csrf_check(mocker: MockerFixture) -> MagicMock:
    """Patch ``login_module.check_csrf_token`` and hand the resulting mock to the test."""
    patched_check = mocker.patch.object(login_module, "check_csrf_token")
    return patched_check
@pytest.fixture
def resolve_current_user(mocker: MockerFixture):
    """Return a factory that patches ``login_module._resolve_current_user``.

    Calling the factory with a user (or ``None``) installs a mock returning
    that value and gives the mock back so the test can assert on its calls.
    """

    def _install(user: MockUser | Account | None) -> MagicMock:
        patched_resolver = mocker.patch.object(
            login_module,
            "_resolve_current_user",
            return_value=user,
        )
        return patched_resolver

    return _install
class TestLoginRequired:
    """Behavioural tests for the ``login_required`` decorator."""

    def test_authenticated_user_can_access_protected_view(
        self,
        login_app: Flask,
        protected_view,
        csrf_check: MagicMock,
        resolve_current_user,
    ):
        """An authenticated user passes the CSRF check and reaches the view."""
        resolver = resolve_current_user(MockUser("test_user", is_authenticated=True))

        with login_app.test_request_context():
            outcome = protected_view()

            # The first CSRF argument may be Flask's request proxy, so it is
            # inspected while the request context is still active.
            csrf_check.assert_called_once()
            positional = csrf_check.call_args.args
            assert positional[0].method == "GET"
            assert positional[1] == "test_user"

        assert outcome == "Protected content"
        resolver.assert_called_once_with()
        login_app.login_manager.unauthorized.assert_not_called()

    @pytest.mark.parametrize(
        ("resolved_user", "description"),
        [
            pytest.param(None, "missing user", id="missing-user"),
            pytest.param(MockUser("test_user", is_authenticated=False), "unauthenticated user", id="unauthenticated"),
        ],
    )
    def test_unauthorized_access_returns_login_manager_response(
        self,
        login_app: Flask,
        protected_view,
        csrf_check: MagicMock,
        resolve_current_user,
        resolved_user: MockUser | None,
        description: str,
    ):
        """A missing or unauthenticated user gets the manager's unauthorized response."""
        resolver = resolve_current_user(resolved_user)
        unauthorized = login_app.login_manager.unauthorized

        with login_app.test_request_context():
            outcome = protected_view()

        assert outcome is unauthorized.return_value, description
        assert isinstance(outcome, Response)
        assert outcome.status_code == 401
        resolver.assert_called_once_with()
        unauthorized.assert_called_once_with()
        csrf_check.assert_not_called()

    def test_unauthorized_access_propagates_response_object(
        self,
        login_app: Flask,
        protected_view,
        csrf_check: MagicMock,
        resolve_current_user,
        mocker: MockerFixture,
    ) -> None:
        """The exact ``Response`` produced by the login manager is returned unchanged."""
        resolver = resolve_current_user(None)
        response = Response("Unauthorized", status=401, content_type="application/json")
        # Swap in a stand-in manager whose ``unauthorized`` yields the response above.
        stand_in_manager = SimpleNamespace(unauthorized=lambda: response)
        mocker.patch.object(login_module, "_get_login_manager", return_value=stand_in_manager)

        with login_app.test_request_context():
            outcome = protected_view()

        assert outcome is response
        assert isinstance(outcome, Response)
        resolver.assert_called_once_with()
        csrf_check.assert_not_called()

    @pytest.mark.parametrize(
        ("method", "login_disabled"),
        [
            pytest.param("OPTIONS", False, id="options"),
            pytest.param("GET", True, id="login-disabled"),
        ],
    )
    def test_bypass_paths_skip_authentication_and_csrf(
        self,
        login_app: Flask,
        protected_view,
        csrf_check: MagicMock,
        monkeypatch: pytest.MonkeyPatch,
        resolve_current_user,
        method: str,
        login_disabled: bool,
    ):
        """OPTIONS requests and disabled login skip user lookup, CSRF and unauthorized handling."""
        resolver = resolve_current_user(MockUser("test_user"))
        monkeypatch.setattr(login_module.dify_config, "LOGIN_DISABLED", login_disabled)

        with login_app.test_request_context(method=method):
            outcome = protected_view()

        assert outcome == "Protected content"
        for untouched in (resolver, csrf_check, login_app.login_manager.unauthorized):
            untouched.assert_not_called()
class TestGetUser:
    """Tests for the ``_get_user`` helper."""

    def test_get_user_returns_user_from_g(self, login_app: Flask):
        """A user already stored on ``g._login_user`` is returned as-is."""
        expected = MockUser("test_user")

        with login_app.test_request_context():
            g._login_user = expected
            found = login_module._get_user()

        assert found == expected
        assert found.id == "test_user"

    def test_get_user_loads_user_if_not_in_g(self, login_app: Flask, mocker: MockerFixture):
        """When ``g`` holds no user, the login manager is asked to load one."""
        expected = MockUser("test_user")

        def _store_user_on_g() -> None:
            # Stand-in for the manager's loader: it records the user on ``g``.
            g._login_user = expected

        loader = mocker.patch.object(
            login_app.login_manager,
            "load_user_from_request_context",
            side_effect=_store_user_on_g,
        )

        with login_app.test_request_context():
            found = login_module._get_user()

        assert found == expected
        loader.assert_called_once_with()

    def test_get_user_returns_none_without_request_context(self):
        """Outside any request context ``_get_user`` yields ``None``."""
        assert login_module._get_user() is None
class TestCurrentUser:
    """Tests for the ``current_user`` proxy."""

    def test_current_user_proxy_returns_authenticated_user(self, login_app: Flask, mocker: MockerFixture):
        """Attribute access on the proxy reaches the user returned by ``_get_user``."""
        authenticated = MockUser("test_user", is_authenticated=True)
        mocker.patch.object(login_module, "_get_user", return_value=authenticated)

        with login_app.test_request_context():
            assert current_user.id == "test_user"
            assert current_user.is_authenticated is True

    def test_current_user_proxy_raises_attribute_error_when_no_user(self, login_app: Flask, mocker: MockerFixture):
        """With no user resolved, reading ``current_user.id`` raises ``AttributeError``."""
        mocker.patch.object(login_module, "_get_user", return_value=None)

        with login_app.test_request_context(), pytest.raises(AttributeError):
            _ = current_user.id
class TestCurrentAccountWithTenant:
    """Tests for the ``current_account_with_tenant`` helper."""

    def test_returns_account_and_tenant_id(self, mocker: MockerFixture):
        """The helper unwraps the proxy and returns the account with its tenant id."""
        account = Account(name="Test User", email="test@example.com")
        account._current_tenant = SimpleNamespace(id="tenant-123")

        proxy = mocker.Mock()
        proxy._get_current_object.return_value = account
        mocker.patch.object(login_module, "current_user", new=proxy)

        returned_user, returned_tenant_id = login_module.current_account_with_tenant()

        assert returned_user is account
        assert returned_tenant_id == "tenant-123"
        proxy._get_current_object.assert_called_once_with()

    def test_raises_when_current_user_is_not_account(self, mocker: MockerFixture):
        """A non-``Account`` current user is rejected with ``ValueError``."""
        not_an_account = MockUser("test_user")
        mocker.patch.object(login_module, "current_user", new=not_an_account)

        with pytest.raises(ValueError, match="current_user must be an Account instance"):
            login_module.current_account_with_tenant()

    def test_raises_when_account_has_no_tenant(self, mocker: MockerFixture):
        """An account with no tenant assigned in the test triggers the helper's assertion."""
        tenantless = Account(name="Test User", email="test@example.com")
        mocker.patch.object(login_module, "current_user", new=tenantless)

        with pytest.raises(AssertionError, match="tenant information should be loaded"):
            login_module.current_account_with_tenant()