test: migrate oauth tests to testcontainers (#33973)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Desel72 2026-03-24 07:56:40 -05:00 committed by GitHub
parent a813b9f103
commit 542c1a14e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,10 @@
"""Testcontainers integration tests for OAuth controller endpoints."""
from __future__ import annotations
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from flask import Flask
from controllers.console.auth.oauth import ( from controllers.console.auth.oauth import (
OAuthCallback, OAuthCallback,
@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError
class TestGetOAuthProviders: class TestGetOAuthProviders:
@pytest.fixture @pytest.fixture
def app(self): def app(self, flask_app_with_containers):
app = Flask(__name__) return flask_app_with_containers
app.config["TESTING"] = True
return app
@pytest.mark.parametrize( @pytest.mark.parametrize(
("github_config", "google_config", "expected_github", "expected_google"), ("github_config", "google_config", "expected_github", "expected_google"),
@ -64,10 +65,8 @@ class TestOAuthLogin:
return OAuthLogin() return OAuthLogin()
@pytest.fixture @pytest.fixture
def app(self): def app(self, flask_app_with_containers):
app = Flask(__name__) return flask_app_with_containers
app.config["TESTING"] = True
return app
@pytest.fixture @pytest.fixture
def mock_oauth_provider(self): def mock_oauth_provider(self):
@ -131,10 +130,8 @@ class TestOAuthCallback:
return OAuthCallback() return OAuthCallback()
@pytest.fixture @pytest.fixture
def app(self): def app(self, flask_app_with_containers):
app = Flask(__name__) return flask_app_with_containers
app.config["TESTING"] = True
return app
@pytest.fixture @pytest.fixture
def oauth_setup(self): def oauth_setup(self):
@ -190,15 +187,8 @@ class TestOAuthCallback:
(KeyError("Missing key"), "OAuth process failed"), (KeyError("Missing key"), "OAuth process failed"),
], ],
) )
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth.get_oauth_providers")
def test_should_handle_oauth_exceptions( def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error):
self, mock_get_providers, mock_db, resource, app, exception, expected_error
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
# Import the real requests module to create a proper exception # Import the real requests module to create a proper exception
import httpx import httpx
@ -258,7 +248,6 @@ class TestOAuthCallback:
) )
@patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account") @patch("controllers.console.auth.oauth._generate_account")
@ -269,7 +258,6 @@ class TestOAuthCallback:
mock_generate_account, mock_generate_account,
mock_get_providers, mock_get_providers,
mock_config, mock_config,
mock_db,
mock_tenant_service, mock_tenant_service,
mock_account_service, mock_account_service,
resource, resource,
@ -278,10 +266,6 @@ class TestOAuthCallback:
account_status, account_status,
expected_redirect, expected_redirect,
): ):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
mock_db.session.commit = MagicMock()
mock_config.CONSOLE_WEB_URL = "http://localhost:3000" mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]} mock_get_providers.return_value = {"github": oauth_setup["provider"]}
@ -306,14 +290,12 @@ class TestOAuthCallback:
@patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account") @patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.AccountService")
def test_should_activate_pending_account( def test_should_activate_pending_account(
self, self,
mock_account_service, mock_account_service,
mock_tenant_service, mock_tenant_service,
mock_db,
mock_generate_account, mock_generate_account,
mock_get_providers, mock_get_providers,
mock_config, mock_config,
@ -338,12 +320,10 @@ class TestOAuthCallback:
assert mock_account.status == AccountStatus.ACTIVE assert mock_account.status == AccountStatus.ACTIVE
assert mock_account.initialized_at is not None assert mock_account.initialized_at is not None
mock_db.session.commit.assert_called_once()
@patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account") @patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.redirect") @patch("controllers.console.auth.oauth.redirect")
@ -352,7 +332,6 @@ class TestOAuthCallback:
mock_redirect, mock_redirect,
mock_account_service, mock_account_service,
mock_tenant_service, mock_tenant_service,
mock_db,
mock_generate_account, mock_generate_account,
mock_get_providers, mock_get_providers,
mock_config, mock_config,
@ -414,6 +393,10 @@ class TestOAuthCallback:
class TestAccountGeneration: class TestAccountGeneration:
@pytest.fixture
def app(self, flask_app_with_containers):
return flask_app_with_containers
@pytest.fixture @pytest.fixture
def user_info(self): def user_info(self):
return OAuthUserInfo(id="123", name="Test User", email="test@example.com") return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
@ -425,15 +408,10 @@ class TestAccountGeneration:
return account return account
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.oauth.Session")
@patch("controllers.console.auth.oauth.Account") @patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.db")
def test_should_get_account_by_openid_or_email( def test_should_get_account_by_openid_or_email(
self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account
): ):
# Mock db.engine for Session creation
mock_db.engine = MagicMock()
# Test OpenID found # Test OpenID found
mock_account_model.get_by_openid.return_value = mock_account mock_account_model.get_by_openid.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info) result = _get_account_by_openid_or_email("github", user_info)
@ -443,15 +421,14 @@ class TestAccountGeneration:
# Test fallback to email lookup # Test fallback to email lookup
mock_account_model.get_by_openid.return_value = None mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_get_account.return_value = mock_account mock_get_account.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info) result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account assert result == mock_account
mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) mock_get_account.assert_called_once()
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self):
"""Test that case fallback tries lowercase when exact match fails."""
mock_session = MagicMock() mock_session = MagicMock()
first_result = MagicMock() first_result = MagicMock()
first_result.scalar_one_or_none.return_value = None first_result.scalar_one_or_none.return_value = None
@ -462,7 +439,7 @@ class TestAccountGeneration:
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session)
assert result == expected_account assert result is expected_account
assert mock_session.execute.call_count == 2 assert mock_session.execute.call_count == 2
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -478,10 +455,8 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_handle_account_generation_scenarios( def test_should_handle_account_generation_scenarios(
self, self,
mock_db,
mock_tenant_service, mock_tenant_service,
mock_account_service, mock_account_service,
mock_register_service, mock_register_service,
@ -519,10 +494,8 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_register_with_lowercase_email( def test_should_register_with_lowercase_email(
self, self,
mock_db,
mock_tenant_service, mock_tenant_service,
mock_account_service, mock_account_service,
mock_register_service, mock_register_service,