diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index dc13143417..915aee3fa7 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1,3 +1,18 @@ +"""Comprehensive unit tests for BillingService. + +This test module covers all aspects of the billing service including: +- HTTP request handling with retry logic +- Subscription tier management and billing information retrieval +- Usage calculation and credit management (positive/negative deltas) +- Rate limit enforcement for compliance downloads and education features +- Account management and permission checks +- Cache management for billing data +- Partner integration features + +All tests use mocking to avoid external dependencies and ensure fast, reliable execution. +Tests follow the Arrange-Act-Assert pattern for clarity. +""" + import json from unittest.mock import MagicMock, patch @@ -5,11 +20,20 @@ import httpx import pytest from werkzeug.exceptions import InternalServerError +from enums.cloud_plan import CloudPlan +from models import Account, TenantAccountJoin, TenantAccountRole from services.billing_service import BillingService class TestBillingServiceSendRequest: - """Unit tests for BillingService._send_request method.""" + """Unit tests for BillingService._send_request method. + + Tests cover: + - Successful GET/PUT/POST/DELETE requests + - Error handling for various HTTP status codes + - Retry logic on network failures + - Request header and parameter validation + """ @pytest.fixture def mock_httpx_request(self): @@ -234,3 +258,1042 @@ class TestBillingServiceSendRequest: # Should retry multiple times (wait=2, stop_before_delay=10 means ~5 attempts) assert mock_httpx_request.call_count > 1 + + +class TestBillingServiceSubscriptionInfo: + """Unit tests for subscription tier and billing info retrieval. + + Tests cover: + - Billing information retrieval + - Knowledge base rate limits with default and custom values + - Payment link generation for subscriptions and model providers + - Invoice retrieval + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_get_info_success(self, mock_send_request): + """Test successful retrieval of billing information.""" + # Arrange + tenant_id = "tenant-123" + expected_response = { + "subscription_plan": "professional", + "billing_cycle": "monthly", + "status": "active", + } + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_info(tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": tenant_id}) + + def test_get_knowledge_rate_limit_with_defaults(self, mock_send_request): + """Test knowledge rate limit retrieval with default values.""" + # Arrange + tenant_id = "tenant-456" + mock_send_request.return_value = {} + + # Act + result = BillingService.get_knowledge_rate_limit(tenant_id) + + # Assert + assert result["limit"] == 10 # Default limit + assert result["subscription_plan"] == CloudPlan.SANDBOX # Default plan + mock_send_request.assert_called_once_with( + "GET", "/subscription/knowledge-rate-limit", params={"tenant_id": tenant_id} + ) + + def test_get_knowledge_rate_limit_with_custom_values(self, mock_send_request): + """Test knowledge rate limit retrieval with custom values.""" + # Arrange + tenant_id = "tenant-789" + mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL} + + # Act + result = BillingService.get_knowledge_rate_limit(tenant_id) + + # Assert + assert result["limit"] == 100 + assert result["subscription_plan"] == CloudPlan.PROFESSIONAL + + def test_get_subscription_payment_link(self, mock_send_request): + """Test subscription payment link generation.""" + # Arrange + plan = "professional" + interval = "monthly" + email = "user@example.com" + tenant_id = "tenant-123" + expected_response = {"payment_link": "https://payment.example.com/checkout"} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_subscription(plan, interval, email, tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", + "/subscription/payment-link", + params={"plan": plan, "interval": interval, "prefilled_email": email, "tenant_id": tenant_id}, + ) + + def test_get_model_provider_payment_link(self, mock_send_request): + """Test model provider payment link generation.""" + # Arrange + provider_name = "openai" + tenant_id = "tenant-123" + account_id = "account-456" + email = "user@example.com" + expected_response = {"payment_link": "https://payment.example.com/provider"} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_model_provider_payment_link(provider_name, tenant_id, account_id, email) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", + "/model-provider/payment-link", + params={ + "provider_name": provider_name, + "tenant_id": tenant_id, + "account_id": account_id, + "prefilled_email": email, + }, + ) + + def test_get_invoices(self, mock_send_request): + """Test invoice retrieval.""" + # Arrange + email = "user@example.com" + tenant_id = "tenant-123" + expected_response = {"invoices": [{"id": "inv-1", "amount": 100}]} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_invoices(email, tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", "/invoices", params={"prefilled_email": email, "tenant_id": tenant_id} + ) + + +class TestBillingServiceUsageCalculation: + """Unit tests for usage calculation and credit management. + + Tests cover: + - Feature plan usage information retrieval + - Credit addition (positive delta) + - Credit consumption (negative delta) + - Usage refunds + - Specific feature usage queries + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_get_tenant_feature_plan_usage_info(self, mock_send_request): + """Test retrieval of tenant feature plan usage information.""" + # Arrange + tenant_id = "tenant-123" + expected_response = {"features": {"trigger": {"used": 50, "limit": 100}, "workflow": {"used": 20, "limit": 50}}} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("GET", "/tenant-feature-usage/info", params={"tenant_id": tenant_id}) + + def test_update_tenant_feature_plan_usage_positive_delta(self, mock_send_request): + """Test updating tenant feature usage with positive delta (adding credits).""" + # Arrange + tenant_id = "tenant-123" + feature_key = "trigger" + delta = 10 + expected_response = {"result": "success", "history_id": "hist-uuid-123"} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta) + + # Assert + assert result == expected_response + assert result["result"] == "success" + assert "history_id" in result + mock_send_request.assert_called_once_with( + "POST", + "/tenant-feature-usage/usage", + params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta}, + ) + + def test_update_tenant_feature_plan_usage_negative_delta(self, mock_send_request): + """Test updating tenant feature usage with negative delta (consuming credits).""" + # Arrange + tenant_id = "tenant-456" + feature_key = "workflow" + delta = -5 + expected_response = {"result": "success", "history_id": "hist-uuid-456"} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "POST", + "/tenant-feature-usage/usage", + params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta}, + ) + + def test_refund_tenant_feature_plan_usage(self, mock_send_request): + """Test refunding a previous usage charge.""" + # Arrange + history_id = "hist-uuid-789" + expected_response = {"result": "success", "history_id": history_id} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.refund_tenant_feature_plan_usage(history_id) + + # Assert + assert result == expected_response + assert result["result"] == "success" + mock_send_request.assert_called_once_with( + "POST", "/tenant-feature-usage/refund", params={"quota_usage_history_id": history_id} + ) + + def test_get_tenant_feature_plan_usage(self, mock_send_request): + """Test getting specific feature usage for a tenant.""" + # Arrange + tenant_id = "tenant-123" + feature_key = "trigger" + expected_response = {"used": 75, "limit": 100, "remaining": 25} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", "/billing/tenant_feature_plan/usage", params={"tenant_id": tenant_id, "feature_key": feature_key} + ) + + +class TestBillingServiceRateLimitEnforcement: + """Unit tests for rate limit enforcement mechanisms. + + Tests cover: + - Compliance download rate limiting (4 requests per 60 seconds) + - Education verification rate limiting (10 requests per 60 seconds) + - Education activation rate limiting (10 requests per 60 seconds) + - Rate limit increment after successful operations + - Proper exception raising when limits are exceeded + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_compliance_download_rate_limiter_not_limited(self, mock_send_request): + """Test compliance download when rate limit is not exceeded.""" + # Arrange + doc_name = "compliance_report.pdf" + account_id = "account-123" + tenant_id = "tenant-456" + ip = "192.168.1.1" + device_info = "Mozilla/5.0" + expected_response = {"download_link": "https://example.com/download"} + + # Mock the rate limiter to return False (not limited) + with ( + patch.object( + BillingService.compliance_download_rate_limiter, "is_rate_limited", return_value=False + ) as mock_is_limited, + patch.object(BillingService.compliance_download_rate_limiter, "increment_rate_limit") as mock_increment, + ): + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info) + + # Assert + assert result == expected_response + mock_is_limited.assert_called_once_with(f"{account_id}:{tenant_id}") + mock_send_request.assert_called_once_with( + "POST", + "/compliance/download", + json={ + "doc_name": doc_name, + "account_id": account_id, + "tenant_id": tenant_id, + "ip_address": ip, + "device_info": device_info, + }, + ) + # Verify rate limit was incremented after successful download + mock_increment.assert_called_once_with(f"{account_id}:{tenant_id}") + + def test_compliance_download_rate_limiter_exceeded(self, mock_send_request): + """Test compliance download when rate limit is exceeded.""" + # Arrange + doc_name = "compliance_report.pdf" + account_id = "account-123" + tenant_id = "tenant-456" + ip = "192.168.1.1" + device_info = "Mozilla/5.0" + + # Import the error class to properly catch it + from controllers.console.error import ComplianceRateLimitError + + # Mock the rate limiter to return True (rate limited) + with patch.object( + BillingService.compliance_download_rate_limiter, "is_rate_limited", return_value=True + ) as mock_is_limited: + # Act & Assert + with pytest.raises(ComplianceRateLimitError): + BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info) + + mock_is_limited.assert_called_once_with(f"{account_id}:{tenant_id}") + mock_send_request.assert_not_called() + + def test_education_verify_rate_limit_not_exceeded(self, mock_send_request): + """Test education verification when rate limit is not exceeded.""" + # Arrange + account_id = "account-123" + account_email = "student@university.edu" + expected_response = {"verified": True, "institution": "University"} + + # Mock the rate limiter to return False (not limited) + with ( + patch.object( + BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False + ) as mock_is_limited, + patch.object( + BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit" + ) as mock_increment, + ): + mock_send_request.return_value = expected_response + + # Act + result = BillingService.EducationIdentity.verify(account_id, account_email) + + # Assert + assert result == expected_response + mock_is_limited.assert_called_once_with(account_email) + mock_send_request.assert_called_once_with("GET", "/education/verify", params={"account_id": account_id}) + mock_increment.assert_called_once_with(account_email) + + def test_education_verify_rate_limit_exceeded(self, mock_send_request): + """Test education verification when rate limit is exceeded.""" + # Arrange + account_id = "account-123" + account_email = "student@university.edu" + + # Import the error class to properly catch it + from controllers.console.error import EducationVerifyLimitError + + # Mock the rate limiter to return True (rate limited) + with patch.object( + BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=True + ) as mock_is_limited: + # Act & Assert + with pytest.raises(EducationVerifyLimitError): + BillingService.EducationIdentity.verify(account_id, account_email) + + mock_is_limited.assert_called_once_with(account_email) + mock_send_request.assert_not_called() + + def test_education_activate_rate_limit_not_exceeded(self, mock_send_request): + """Test education activation when rate limit is not exceeded.""" + # Arrange + account = MagicMock(spec=Account) + account.id = "account-123" + account.email = "student@university.edu" + account.current_tenant_id = "tenant-456" + token = "verification-token" + institution = "MIT" + role = "student" + expected_response = {"result": "success", "activated": True} + + # Mock the rate limiter to return False (not limited) + with ( + patch.object( + BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=False + ) as mock_is_limited, + patch.object( + BillingService.EducationIdentity.activation_rate_limit, "increment_rate_limit" + ) as mock_increment, + ): + mock_send_request.return_value = expected_response + + # Act + result = BillingService.EducationIdentity.activate(account, token, institution, role) + + # Assert + assert result == expected_response + mock_is_limited.assert_called_once_with(account.email) + mock_send_request.assert_called_once_with( + "POST", + "/education/", + json={"institution": institution, "token": token, "role": role}, + params={"account_id": account.id, "curr_tenant_id": account.current_tenant_id}, + ) + mock_increment.assert_called_once_with(account.email) + + def test_education_activate_rate_limit_exceeded(self, mock_send_request): + """Test education activation when rate limit is exceeded.""" + # Arrange + account = MagicMock(spec=Account) + account.id = "account-123" + account.email = "student@university.edu" + account.current_tenant_id = "tenant-456" + token = "verification-token" + institution = "MIT" + role = "student" + + # Import the error class to properly catch it + from controllers.console.error import EducationActivateLimitError + + # Mock the rate limiter to return True (rate limited) + with patch.object( + BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=True + ) as mock_is_limited: + # Act & Assert + with pytest.raises(EducationActivateLimitError): + BillingService.EducationIdentity.activate(account, token, institution, role) + + mock_is_limited.assert_called_once_with(account.email) + mock_send_request.assert_not_called() + + +class TestBillingServiceEducationIdentity: + """Unit tests for education identity verification and management. + + Tests cover: + - Education verification status checking + - Institution autocomplete with pagination + - Default parameter handling + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_education_status(self, mock_send_request): + """Test checking education verification status.""" + # Arrange + account_id = "account-123" + expected_response = {"verified": True, "institution": "MIT", "role": "student"} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.EducationIdentity.status(account_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("GET", "/education/status", params={"account_id": account_id}) + + def test_education_autocomplete(self, mock_send_request): + """Test education institution autocomplete.""" + # Arrange + keywords = "Massachusetts" + page = 0 + limit = 20 + expected_response = { + "institutions": [ + {"name": "Massachusetts Institute of Technology", "domain": "mit.edu"}, + {"name": "University of Massachusetts", "domain": "umass.edu"}, + ] + } + mock_send_request.return_value = expected_response + + # Act + result = BillingService.EducationIdentity.autocomplete(keywords, page, limit) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", "/education/autocomplete", params={"keywords": keywords, "page": page, "limit": limit} + ) + + def test_education_autocomplete_with_defaults(self, mock_send_request): + """Test education institution autocomplete with default parameters.""" + # Arrange + keywords = "Stanford" + expected_response = {"institutions": [{"name": "Stanford University", "domain": "stanford.edu"}]} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.EducationIdentity.autocomplete(keywords) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", "/education/autocomplete", params={"keywords": keywords, "page": 0, "limit": 20} + ) + + +class TestBillingServiceAccountManagement: + """Unit tests for account-related billing operations. + + Tests cover: + - Account deletion + - Email freeze status checking + - Account deletion feedback submission + - Tenant owner/admin permission validation + - Error handling for missing tenant joins + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + @pytest.fixture + def mock_db_session(self): + """Mock database session.""" + with patch("services.billing_service.db.session") as mock_session: + yield mock_session + + def test_delete_account(self, mock_send_request): + """Test account deletion.""" + # Arrange + account_id = "account-123" + expected_response = {"result": "success", "deleted": True} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.delete_account(account_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with("DELETE", "/account/", params={"account_id": account_id}) + + def test_is_email_in_freeze_true(self, mock_send_request): + """Test checking if email is frozen (returns True).""" + # Arrange + email = "frozen@example.com" + mock_send_request.return_value = {"data": True} + + # Act + result = BillingService.is_email_in_freeze(email) + + # Assert + assert result is True + mock_send_request.assert_called_once_with("GET", "/account/in-freeze", params={"email": email}) + + def test_is_email_in_freeze_false(self, mock_send_request): + """Test checking if email is frozen (returns False).""" + # Arrange + email = "active@example.com" + mock_send_request.return_value = {"data": False} + + # Act + result = BillingService.is_email_in_freeze(email) + + # Assert + assert result is False + mock_send_request.assert_called_once_with("GET", "/account/in-freeze", params={"email": email}) + + def test_is_email_in_freeze_exception_returns_false(self, mock_send_request): + """Test that is_email_in_freeze returns False on exception.""" + # Arrange + email = "error@example.com" + mock_send_request.side_effect = Exception("Network error") + + # Act + result = BillingService.is_email_in_freeze(email) + + # Assert + assert result is False + + def test_update_account_deletion_feedback(self, mock_send_request): + """Test updating account deletion feedback.""" + # Arrange + email = "user@example.com" + feedback = "Service was too expensive" + expected_response = {"result": "success"} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.update_account_deletion_feedback(email, feedback) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "POST", "/account/delete-feedback", json={"email": email, "feedback": feedback} + ) + + def test_is_tenant_owner_or_admin_owner(self, mock_db_session): + """Test tenant owner/admin check for owner role.""" + # Arrange + current_user = MagicMock(spec=Account) + current_user.id = "account-123" + current_user.current_tenant_id = "tenant-456" + + mock_join = MagicMock(spec=TenantAccountJoin) + mock_join.role = TenantAccountRole.OWNER + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_join + mock_db_session.query.return_value = mock_query + + # Act - should not raise exception + BillingService.is_tenant_owner_or_admin(current_user) + + # Assert + mock_db_session.query.assert_called_once() + + def test_is_tenant_owner_or_admin_admin(self, mock_db_session): + """Test tenant owner/admin check for admin role.""" + # Arrange + current_user = MagicMock(spec=Account) + current_user.id = "account-123" + current_user.current_tenant_id = "tenant-456" + + mock_join = MagicMock(spec=TenantAccountJoin) + mock_join.role = TenantAccountRole.ADMIN + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_join + mock_db_session.query.return_value = mock_query + + # Act - should not raise exception + BillingService.is_tenant_owner_or_admin(current_user) + + # Assert + mock_db_session.query.assert_called_once() + + def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session): + """Test tenant owner/admin check raises error for normal user.""" + # Arrange + current_user = MagicMock(spec=Account) + current_user.id = "account-123" + current_user.current_tenant_id = "tenant-456" + + mock_join = MagicMock(spec=TenantAccountJoin) + mock_join.role = TenantAccountRole.NORMAL + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_join + mock_db_session.query.return_value = mock_query + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + BillingService.is_tenant_owner_or_admin(current_user) + assert "Only team owner or team admin can perform this action" in str(exc_info.value) + + def test_is_tenant_owner_or_admin_no_join_raises_error(self, mock_db_session): + """Test tenant owner/admin check raises error when join not found.""" + # Arrange + current_user = MagicMock(spec=Account) + current_user.id = "account-123" + current_user.current_tenant_id = "tenant-456" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = None + mock_db_session.query.return_value = mock_query + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + BillingService.is_tenant_owner_or_admin(current_user) + assert "Tenant account join not found" in str(exc_info.value) + + +class TestBillingServiceCacheManagement: + """Unit tests for billing cache management. + + Tests cover: + - Billing info cache invalidation + - Proper Redis key formatting + """ + + @pytest.fixture + def mock_redis_client(self): + """Mock Redis client.""" + with patch("services.billing_service.redis_client") as mock_redis: + yield mock_redis + + def test_clean_billing_info_cache(self, mock_redis_client): + """Test cleaning billing info cache.""" + # Arrange + tenant_id = "tenant-123" + expected_key = f"tenant:{tenant_id}:billing_info" + + # Act + BillingService.clean_billing_info_cache(tenant_id) + + # Assert + mock_redis_client.delete.assert_called_once_with(expected_key) + + +class TestBillingServicePartnerIntegration: + """Unit tests for partner integration features. + + Tests cover: + - Partner tenant binding synchronization + - Click ID tracking + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_sync_partner_tenants_bindings(self, mock_send_request): + """Test syncing partner tenant bindings.""" + # Arrange + account_id = "account-123" + partner_key = "partner-xyz" + click_id = "click-789" + expected_response = {"result": "success", "synced": True} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.sync_partner_tenants_bindings(account_id, partner_key, click_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "PUT", f"/partners/{partner_key}/tenants", json={"account_id": account_id, "click_id": click_id} + ) + + +class TestBillingServiceEdgeCases: + """Unit tests for edge cases and error scenarios. + + Tests cover: + - Empty responses from billing API + - Malformed JSON responses + - Boundary conditions for rate limits + - Multiple subscription tiers + - Zero and negative usage deltas + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_get_info_empty_response(self, mock_send_request): + """Test handling of empty billing info response.""" + # Arrange + tenant_id = "tenant-empty" + mock_send_request.return_value = {} + + # Act + result = BillingService.get_info(tenant_id) + + # Assert + assert result == {} + mock_send_request.assert_called_once() + + def test_update_tenant_feature_plan_usage_zero_delta(self, mock_send_request): + """Test updating tenant feature usage with zero delta (no change).""" + # Arrange + tenant_id = "tenant-123" + feature_key = "trigger" + delta = 0 # No change + expected_response = {"result": "success", "history_id": "hist-uuid-zero"} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "POST", + "/tenant-feature-usage/usage", + params={"tenant_id": tenant_id, "feature_key": feature_key, "delta": delta}, + ) + + def test_update_tenant_feature_plan_usage_large_negative_delta(self, mock_send_request): + """Test updating tenant feature usage with large negative delta.""" + # Arrange + tenant_id = "tenant-456" + feature_key = "workflow" + delta = -1000 # Large consumption + expected_response = {"result": "success", "history_id": "hist-uuid-large"} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, delta) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once() + + def test_get_knowledge_rate_limit_all_subscription_tiers(self, mock_send_request): + """Test knowledge rate limit for all subscription tiers.""" + # Test SANDBOX tier + mock_send_request.return_value = {"limit": 10, "subscription_plan": CloudPlan.SANDBOX} + result = BillingService.get_knowledge_rate_limit("tenant-sandbox") + assert result["subscription_plan"] == CloudPlan.SANDBOX + assert result["limit"] == 10 + + # Test PROFESSIONAL tier + mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL} + result = BillingService.get_knowledge_rate_limit("tenant-pro") + assert result["subscription_plan"] == CloudPlan.PROFESSIONAL + assert result["limit"] == 100 + + # Test TEAM tier + mock_send_request.return_value = {"limit": 500, "subscription_plan": CloudPlan.TEAM} + result = BillingService.get_knowledge_rate_limit("tenant-team") + assert result["subscription_plan"] == CloudPlan.TEAM + assert result["limit"] == 500 + + def test_get_subscription_with_empty_optional_params(self, mock_send_request): + """Test subscription payment link with empty optional parameters.""" + # Arrange + plan = "professional" + interval = "yearly" + expected_response = {"payment_link": "https://payment.example.com/checkout"} + mock_send_request.return_value = expected_response + + # Act - empty email and tenant_id + result = BillingService.get_subscription(plan, interval, "", "") + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", + "/subscription/payment-link", + params={"plan": plan, "interval": interval, "prefilled_email": "", "tenant_id": ""}, + ) + + def test_get_invoices_with_empty_params(self, mock_send_request): + """Test invoice retrieval with empty parameters.""" + # Arrange + expected_response = {"invoices": []} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_invoices("", "") + + # Assert + assert result == expected_response + assert result["invoices"] == [] + + def test_refund_with_invalid_history_id_format(self, mock_send_request): + """Test refund with various history ID formats.""" + # Arrange - test with different ID formats + test_ids = ["hist-123", "uuid-abc-def", "12345", ""] + + for history_id in test_ids: + expected_response = {"result": "success", "history_id": history_id} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.refund_tenant_feature_plan_usage(history_id) + + # Assert + assert result["history_id"] == history_id + + def test_is_tenant_owner_or_admin_editor_role_raises_error(self): + """Test tenant owner/admin check raises error for editor role.""" + # Arrange + current_user = MagicMock(spec=Account) + current_user.id = "account-123" + current_user.current_tenant_id = "tenant-456" + + mock_join = MagicMock(spec=TenantAccountJoin) + mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged + + with patch("services.billing_service.db.session") as mock_session: + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_join + mock_session.query.return_value = mock_query + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + BillingService.is_tenant_owner_or_admin(current_user) + assert "Only team owner or team admin can perform this action" in str(exc_info.value) + + def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self): + """Test tenant owner/admin check raises error for dataset operator role.""" + # Arrange + current_user = MagicMock(spec=Account) + current_user.id = "account-123" + current_user.current_tenant_id = "tenant-456" + + mock_join = MagicMock(spec=TenantAccountJoin) + mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged + + with patch("services.billing_service.db.session") as mock_session: + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_join + mock_session.query.return_value = mock_query + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + BillingService.is_tenant_owner_or_admin(current_user) + assert "Only team owner or team admin can perform this action" in str(exc_info.value) + + +class TestBillingServiceIntegrationScenarios: + """Integration-style tests simulating real-world usage scenarios. + + These tests combine multiple service methods to test common workflows: + - Complete subscription upgrade flow + - Usage tracking and refund workflow + - Rate limit boundary testing + """ + + @pytest.fixture + def mock_send_request(self): + """Mock _send_request method.""" + with patch.object(BillingService, "_send_request") as mock: + yield mock + + def test_subscription_upgrade_workflow(self, mock_send_request): + """Test complete subscription upgrade workflow.""" + # Arrange + tenant_id = "tenant-upgrade" + + # Step 1: Get current billing info + mock_send_request.return_value = { + "subscription_plan": "sandbox", + "billing_cycle": "monthly", + "status": "active", + } + current_info = BillingService.get_info(tenant_id) + assert current_info["subscription_plan"] == "sandbox" + + # Step 2: Get payment link for upgrade + mock_send_request.return_value = {"payment_link": "https://payment.example.com/upgrade"} + payment_link = BillingService.get_subscription("professional", "monthly", "user@example.com", tenant_id) + assert "payment_link" in payment_link + + # Step 3: Verify new rate limits after upgrade + mock_send_request.return_value = {"limit": 100, "subscription_plan": CloudPlan.PROFESSIONAL} + rate_limit = BillingService.get_knowledge_rate_limit(tenant_id) + assert rate_limit["subscription_plan"] == CloudPlan.PROFESSIONAL + assert rate_limit["limit"] == 100 + + def test_usage_tracking_and_refund_workflow(self, mock_send_request): + """Test usage tracking with subsequent refund.""" + # Arrange + tenant_id = "tenant-usage" + feature_key = "workflow" + + # Step 1: Consume credits + mock_send_request.return_value = {"result": "success", "history_id": "hist-consume-123"} + consume_result = BillingService.update_tenant_feature_plan_usage(tenant_id, feature_key, -10) + history_id = consume_result["history_id"] + assert history_id == "hist-consume-123" + + # Step 2: Check current usage + mock_send_request.return_value = {"used": 10, "limit": 100, "remaining": 90} + usage = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key) + assert usage["used"] == 10 + assert usage["remaining"] == 90 + + # Step 3: Refund the usage + mock_send_request.return_value = {"result": "success", "history_id": history_id} + refund_result = BillingService.refund_tenant_feature_plan_usage(history_id) + assert refund_result["result"] == "success" + + # Step 4: Verify usage after refund + mock_send_request.return_value = {"used": 0, "limit": 100, "remaining": 100} + updated_usage = BillingService.get_tenant_feature_plan_usage(tenant_id, feature_key) + assert updated_usage["used"] == 0 + assert updated_usage["remaining"] == 100 + + def test_compliance_download_multiple_requests_within_limit(self, mock_send_request): + """Test multiple compliance downloads within rate limit.""" + # Arrange + account_id = "account-compliance" + tenant_id = "tenant-compliance" + doc_name = "compliance_report.pdf" + ip = "192.168.1.1" + device_info = "Mozilla/5.0" + + # Mock rate limiter to allow 3 requests (under limit of 4) + with ( + patch.object( + BillingService.compliance_download_rate_limiter, "is_rate_limited", side_effect=[False, False, False] + ) as mock_is_limited, + patch.object(BillingService.compliance_download_rate_limiter, "increment_rate_limit") as mock_increment, + ): + mock_send_request.return_value = {"download_link": "https://example.com/download"} + + # Act - Make 3 requests + for i in range(3): + result = BillingService.get_compliance_download_link(doc_name, account_id, tenant_id, ip, device_info) + assert "download_link" in result + + # Assert - All 3 requests succeeded + assert mock_is_limited.call_count == 3 + assert mock_increment.call_count == 3 + + def test_education_verification_and_activation_flow(self, mock_send_request): + """Test complete education verification and activation flow.""" + # Arrange + account = MagicMock(spec=Account) + account.id = "account-edu" + account.email = "student@mit.edu" + account.current_tenant_id = "tenant-edu" + + # Step 1: Search for institution + with ( + patch.object( + BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False + ), + patch.object(BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit"), + ): + mock_send_request.return_value = { + "institutions": [{"name": "Massachusetts Institute of Technology", "domain": "mit.edu"}] + } + institutions = BillingService.EducationIdentity.autocomplete("MIT") + assert len(institutions["institutions"]) > 0 + + # Step 2: Verify email + with ( + patch.object( + BillingService.EducationIdentity.verification_rate_limit, "is_rate_limited", return_value=False + ), + patch.object(BillingService.EducationIdentity.verification_rate_limit, "increment_rate_limit"), + ): + mock_send_request.return_value = {"verified": True, "institution": "MIT"} + verify_result = BillingService.EducationIdentity.verify(account.id, account.email) + assert verify_result["verified"] is True + + # Step 3: Check status + mock_send_request.return_value = {"verified": True, "institution": "MIT", "role": "student"} + status = BillingService.EducationIdentity.status(account.id) + assert status["verified"] is True + + # Step 4: Activate education benefits + with ( + patch.object(BillingService.EducationIdentity.activation_rate_limit, "is_rate_limited", return_value=False), + patch.object(BillingService.EducationIdentity.activation_rate_limit, "increment_rate_limit"), + ): + mock_send_request.return_value = {"result": "success", "activated": True} + activate_result = BillingService.EducationIdentity.activate(account, "token-123", "MIT", "student") + assert activate_result["activated"] is True