mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
test: migrate oauth tests to testcontainers (#33973)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a813b9f103
commit
542c1a14e0
@ -1,7 +1,10 @@
|
|||||||
|
"""Testcontainers integration tests for OAuth controller endpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from flask import Flask
|
|
||||||
|
|
||||||
from controllers.console.auth.oauth import (
|
from controllers.console.auth.oauth import (
|
||||||
OAuthCallback,
|
OAuthCallback,
|
||||||
@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError
|
|||||||
|
|
||||||
class TestGetOAuthProviders:
|
class TestGetOAuthProviders:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app(self):
|
def app(self, flask_app_with_containers):
|
||||||
app = Flask(__name__)
|
return flask_app_with_containers
|
||||||
app.config["TESTING"] = True
|
|
||||||
return app
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("github_config", "google_config", "expected_github", "expected_google"),
|
("github_config", "google_config", "expected_github", "expected_google"),
|
||||||
@ -64,10 +65,8 @@ class TestOAuthLogin:
|
|||||||
return OAuthLogin()
|
return OAuthLogin()
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app(self):
|
def app(self, flask_app_with_containers):
|
||||||
app = Flask(__name__)
|
return flask_app_with_containers
|
||||||
app.config["TESTING"] = True
|
|
||||||
return app
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_oauth_provider(self):
|
def mock_oauth_provider(self):
|
||||||
@ -131,10 +130,8 @@ class TestOAuthCallback:
|
|||||||
return OAuthCallback()
|
return OAuthCallback()
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app(self):
|
def app(self, flask_app_with_containers):
|
||||||
app = Flask(__name__)
|
return flask_app_with_containers
|
||||||
app.config["TESTING"] = True
|
|
||||||
return app
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def oauth_setup(self):
|
def oauth_setup(self):
|
||||||
@ -190,15 +187,8 @@ class TestOAuthCallback:
|
|||||||
(KeyError("Missing key"), "OAuth process failed"),
|
(KeyError("Missing key"), "OAuth process failed"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@patch("controllers.console.auth.oauth.db")
|
|
||||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
def test_should_handle_oauth_exceptions(
|
def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error):
|
||||||
self, mock_get_providers, mock_db, resource, app, exception, expected_error
|
|
||||||
):
|
|
||||||
# Mock database session
|
|
||||||
mock_db.session = MagicMock()
|
|
||||||
mock_db.session.rollback = MagicMock()
|
|
||||||
|
|
||||||
# Import the real requests module to create a proper exception
|
# Import the real requests module to create a proper exception
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@ -258,7 +248,6 @@ class TestOAuthCallback:
|
|||||||
)
|
)
|
||||||
@patch("controllers.console.auth.oauth.AccountService")
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
@patch("controllers.console.auth.oauth.TenantService")
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
@patch("controllers.console.auth.oauth.db")
|
|
||||||
@patch("controllers.console.auth.oauth.dify_config")
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
@patch("controllers.console.auth.oauth._generate_account")
|
@patch("controllers.console.auth.oauth._generate_account")
|
||||||
@ -269,7 +258,6 @@ class TestOAuthCallback:
|
|||||||
mock_generate_account,
|
mock_generate_account,
|
||||||
mock_get_providers,
|
mock_get_providers,
|
||||||
mock_config,
|
mock_config,
|
||||||
mock_db,
|
|
||||||
mock_tenant_service,
|
mock_tenant_service,
|
||||||
mock_account_service,
|
mock_account_service,
|
||||||
resource,
|
resource,
|
||||||
@ -278,10 +266,6 @@ class TestOAuthCallback:
|
|||||||
account_status,
|
account_status,
|
||||||
expected_redirect,
|
expected_redirect,
|
||||||
):
|
):
|
||||||
# Mock database session
|
|
||||||
mock_db.session = MagicMock()
|
|
||||||
mock_db.session.rollback = MagicMock()
|
|
||||||
mock_db.session.commit = MagicMock()
|
|
||||||
|
|
||||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||||
@ -306,14 +290,12 @@ class TestOAuthCallback:
|
|||||||
@patch("controllers.console.auth.oauth.dify_config")
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
@patch("controllers.console.auth.oauth._generate_account")
|
@patch("controllers.console.auth.oauth._generate_account")
|
||||||
@patch("controllers.console.auth.oauth.db")
|
|
||||||
@patch("controllers.console.auth.oauth.TenantService")
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
@patch("controllers.console.auth.oauth.AccountService")
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
def test_should_activate_pending_account(
|
def test_should_activate_pending_account(
|
||||||
self,
|
self,
|
||||||
mock_account_service,
|
mock_account_service,
|
||||||
mock_tenant_service,
|
mock_tenant_service,
|
||||||
mock_db,
|
|
||||||
mock_generate_account,
|
mock_generate_account,
|
||||||
mock_get_providers,
|
mock_get_providers,
|
||||||
mock_config,
|
mock_config,
|
||||||
@ -338,12 +320,10 @@ class TestOAuthCallback:
|
|||||||
|
|
||||||
assert mock_account.status == AccountStatus.ACTIVE
|
assert mock_account.status == AccountStatus.ACTIVE
|
||||||
assert mock_account.initialized_at is not None
|
assert mock_account.initialized_at is not None
|
||||||
mock_db.session.commit.assert_called_once()
|
|
||||||
|
|
||||||
@patch("controllers.console.auth.oauth.dify_config")
|
@patch("controllers.console.auth.oauth.dify_config")
|
||||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||||
@patch("controllers.console.auth.oauth._generate_account")
|
@patch("controllers.console.auth.oauth._generate_account")
|
||||||
@patch("controllers.console.auth.oauth.db")
|
|
||||||
@patch("controllers.console.auth.oauth.TenantService")
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
@patch("controllers.console.auth.oauth.AccountService")
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
@patch("controllers.console.auth.oauth.redirect")
|
@patch("controllers.console.auth.oauth.redirect")
|
||||||
@ -352,7 +332,6 @@ class TestOAuthCallback:
|
|||||||
mock_redirect,
|
mock_redirect,
|
||||||
mock_account_service,
|
mock_account_service,
|
||||||
mock_tenant_service,
|
mock_tenant_service,
|
||||||
mock_db,
|
|
||||||
mock_generate_account,
|
mock_generate_account,
|
||||||
mock_get_providers,
|
mock_get_providers,
|
||||||
mock_config,
|
mock_config,
|
||||||
@ -414,6 +393,10 @@ class TestOAuthCallback:
|
|||||||
|
|
||||||
|
|
||||||
class TestAccountGeneration:
|
class TestAccountGeneration:
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self, flask_app_with_containers):
|
||||||
|
return flask_app_with_containers
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user_info(self):
|
def user_info(self):
|
||||||
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||||
@ -425,15 +408,10 @@ class TestAccountGeneration:
|
|||||||
return account
|
return account
|
||||||
|
|
||||||
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
|
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
|
||||||
@patch("controllers.console.auth.oauth.Session")
|
|
||||||
@patch("controllers.console.auth.oauth.Account")
|
@patch("controllers.console.auth.oauth.Account")
|
||||||
@patch("controllers.console.auth.oauth.db")
|
|
||||||
def test_should_get_account_by_openid_or_email(
|
def test_should_get_account_by_openid_or_email(
|
||||||
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account
|
self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account
|
||||||
):
|
):
|
||||||
# Mock db.engine for Session creation
|
|
||||||
mock_db.engine = MagicMock()
|
|
||||||
|
|
||||||
# Test OpenID found
|
# Test OpenID found
|
||||||
mock_account_model.get_by_openid.return_value = mock_account
|
mock_account_model.get_by_openid.return_value = mock_account
|
||||||
result = _get_account_by_openid_or_email("github", user_info)
|
result = _get_account_by_openid_or_email("github", user_info)
|
||||||
@ -443,15 +421,14 @@ class TestAccountGeneration:
|
|||||||
|
|
||||||
# Test fallback to email lookup
|
# Test fallback to email lookup
|
||||||
mock_account_model.get_by_openid.return_value = None
|
mock_account_model.get_by_openid.return_value = None
|
||||||
mock_session_instance = MagicMock()
|
|
||||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
|
||||||
mock_get_account.return_value = mock_account
|
mock_get_account.return_value = mock_account
|
||||||
|
|
||||||
result = _get_account_by_openid_or_email("github", user_info)
|
result = _get_account_by_openid_or_email("github", user_info)
|
||||||
assert result == mock_account
|
assert result == mock_account
|
||||||
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance)
|
mock_get_account.assert_called_once()
|
||||||
|
|
||||||
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self):
|
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self):
|
||||||
|
"""Test that case fallback tries lowercase when exact match fails."""
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
first_result = MagicMock()
|
first_result = MagicMock()
|
||||||
first_result.scalar_one_or_none.return_value = None
|
first_result.scalar_one_or_none.return_value = None
|
||||||
@ -462,7 +439,7 @@ class TestAccountGeneration:
|
|||||||
|
|
||||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
|
||||||
|
|
||||||
assert result == expected_account
|
assert result is expected_account
|
||||||
assert mock_session.execute.call_count == 2
|
assert mock_session.execute.call_count == 2
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -478,10 +455,8 @@ class TestAccountGeneration:
|
|||||||
@patch("controllers.console.auth.oauth.RegisterService")
|
@patch("controllers.console.auth.oauth.RegisterService")
|
||||||
@patch("controllers.console.auth.oauth.AccountService")
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
@patch("controllers.console.auth.oauth.TenantService")
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
@patch("controllers.console.auth.oauth.db")
|
|
||||||
def test_should_handle_account_generation_scenarios(
|
def test_should_handle_account_generation_scenarios(
|
||||||
self,
|
self,
|
||||||
mock_db,
|
|
||||||
mock_tenant_service,
|
mock_tenant_service,
|
||||||
mock_account_service,
|
mock_account_service,
|
||||||
mock_register_service,
|
mock_register_service,
|
||||||
@ -519,10 +494,8 @@ class TestAccountGeneration:
|
|||||||
@patch("controllers.console.auth.oauth.RegisterService")
|
@patch("controllers.console.auth.oauth.RegisterService")
|
||||||
@patch("controllers.console.auth.oauth.AccountService")
|
@patch("controllers.console.auth.oauth.AccountService")
|
||||||
@patch("controllers.console.auth.oauth.TenantService")
|
@patch("controllers.console.auth.oauth.TenantService")
|
||||||
@patch("controllers.console.auth.oauth.db")
|
|
||||||
def test_should_register_with_lowercase_email(
|
def test_should_register_with_lowercase_email(
|
||||||
self,
|
self,
|
||||||
mock_db,
|
|
||||||
mock_tenant_service,
|
mock_tenant_service,
|
||||||
mock_account_service,
|
mock_account_service,
|
||||||
mock_register_service,
|
mock_register_service,
|
||||||
Loading…
Reference in New Issue
Block a user