From 212b819f1c7fad31420200a471658500ca8f9329 Mon Sep 17 00:00:00 2001 From: Escape0707 Date: Wed, 10 Jun 2026 10:55:50 +0900 Subject: [PATCH] test: migrate credit pool service tests to Testcontainers (#37252) --- .../.ruff.toml | 111 ++++++++- .../conftest.py | 1 + .../pyrefly.toml | 1 - .../plugin/test_plugin_permission_service.py | 4 +- .../services/test_credit_pool_service.py | 210 ++++++++++++++---- .../services/test_trigger_provider_service.py | 33 +-- .../services/test_credit_pool_service.py | 178 --------------- 7 files changed, 284 insertions(+), 254 deletions(-) delete mode 100644 api/tests/unit_tests/services/test_credit_pool_service.py diff --git a/api/tests/test_containers_integration_tests/.ruff.toml b/api/tests/test_containers_integration_tests/.ruff.toml index be0109f462..390eb14851 100644 --- a/api/tests/test_containers_integration_tests/.ruff.toml +++ b/api/tests/test_containers_integration_tests/.ruff.toml @@ -2,18 +2,115 @@ extend = "../../.ruff.toml" src = ["../.."] [lint] -extend-select = ["ANN401", "TID251"] +extend-select = ["ANN401", "ARG", "TID251"] [lint.per-file-ignores] -"**/*.py" = ["S110", "T201"] -"core/rag/pipeline/test_queue_integration.py" = ["ANN401", "TID251"] +"core/rag/pipeline/test_queue_integration.py" = ["ANN401", "TID251", "ARG"] "models/test_types_enum_text.py" = ["ANN401", "TID251"] -"services/test_app_dsl_service.py" = ["ANN401", "TID251"] -"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251"] +"services/test_app_dsl_service.py" = ["ANN401", "TID251", "ARG"] +"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251", "ARG"] "services/test_hit_testing_service.py" = ["ANN401", "TID251"] -"services/test_recommended_app_service.py" = ["ANN401", "TID251"] +"services/test_recommended_app_service.py" = ["ANN401", "TID251", "ARG"] "trigger/conftest.py" = ["ANN401", "TID251"] -"trigger/test_trigger_e2e.py" = ["ANN401", "TID251"] +"trigger/test_trigger_e2e.py" = ["ANN401", "TID251", "ARG"] +"controllers/console/app/test_app_apis.py" = ["ARG"] +"controllers/console/app/test_app_import_api.py" = ["ARG"] +"controllers/console/auth/test_oauth.py" = ["ARG"] +"controllers/console/auth/test_password_reset.py" = ["ARG"] +"controllers/console/datasets/test_data_source.py" = ["ARG"] +"controllers/console/test_apikey.py" = ["ARG"] +"controllers/console/workspace/test_tool_provider.py" = ["ARG"] +"controllers/mcp/test_mcp.py" = ["ARG"] +"controllers/service_api/dataset/test_dataset.py" = ["ARG"] +"controllers/web/test_conversation.py" = ["ARG"] +"controllers/web/test_human_input_form.py" = ["ARG"] +"controllers/web/test_wraps.py" = ["ARG"] +"core/app/layers/test_pause_state_persist_layer.py" = ["ARG"] +"core/rag/retrieval/test_dataset_retrieval_integration.py" = ["ARG"] +"models/test_account.py" = ["ARG"] +"models/test_conversation_message_inputs.py" = ["ARG"] +"models/test_conversation_status_count.py" = ["ARG"] +"repositories/test_sqlalchemy_api_workflow_run_repository.py" = ["ARG"] +"repositories/test_workflow_run_repository.py" = ["ARG"] +"services/auth/test_api_key_auth_service.py" = ["ARG"] +"services/auth/test_auth_integration.py" = ["ARG"] +"services/dataset_collection_binding.py" = ["ARG"] +"services/dataset_service_update_delete.py" = ["ARG"] +"services/document_service_status.py" = ["ARG"] +"services/enterprise/test_account_deletion_sync.py" = ["ARG"] +"services/plugin/test_plugin_parameter_service.py" = ["ARG"] +"services/plugin/test_plugin_service.py" = ["ARG"] +"services/rag_pipeline/test_rag_pipeline_service_db.py" = ["ARG"] +"services/recommend_app/test_database_retrieval.py" = ["ARG"] +"services/test_account_service.py" = ["ARG"] +"services/test_advanced_prompt_template_service.py" = ["ARG"] +"services/test_annotation_service.py" = ["ARG"] +"services/test_api_based_extension_service.py" = ["ARG"] +"services/test_api_token_service.py" = ["ARG"] +"services/test_app_generate_service.py" = ["ARG"] +"services/test_app_service.py" = ["ARG"] +"services/test_attachment_service.py" = ["ARG"] +"services/test_conversation_variable_updater.py" = ["ARG"] +"services/test_dataset_permission_service.py" = ["ARG"] +"services/test_dataset_service_batch_update_document_status.py" = ["ARG"] +"services/test_dataset_service_retrieval.py" = ["ARG"] +"services/test_delete_archived_workflow_run.py" = ["ARG"] +"services/test_document_service_rename_document.py" = ["ARG"] +"services/test_end_user_service.py" = ["ARG"] +"services/test_feature_service.py" = ["ARG"] +"services/test_feedback_service.py" = ["ARG"] +"services/test_file_service.py" = ["ARG"] +"services/test_human_input_delivery_test_service.py" = ["ARG"] +"services/test_message_service.py" = ["ARG"] +"services/test_messages_clean_service.py" = ["ARG", "S110"] +"services/test_metadata_partial_update.py" = ["ARG"] +"services/test_metadata_service.py" = ["ARG"] +"services/test_model_load_balancing_service.py" = ["ARG"] +"services/test_model_provider_service.py" = ["ARG"] +"services/test_oauth_server_service.py" = ["ARG"] +"services/test_ops_service.py" = ["ARG"] +"services/test_saved_message_service.py" = ["ARG"] +"services/test_web_conversation_service.py" = ["ARG"] +"services/test_webapp_auth_service.py" = ["ARG"] +"services/test_webhook_service.py" = ["ARG"] +"services/test_workflow_app_service.py" = ["ARG"] +"services/test_workflow_draft_variable_service.py" = ["ARG"] +"services/test_workflow_run_service.py" = ["ARG"] +"services/test_workflow_service.py" = ["ARG"] +"services/test_workspace_service.py" = ["ARG"] +"services/tools/test_api_tools_manage_service.py" = ["ARG"] +"services/tools/test_mcp_tools_manage_service.py" = ["ARG"] +"services/tools/test_tools_transform_service.py" = ["ARG"] +"services/workflow/test_workflow_converter.py" = ["ARG"] +"tasks/test_add_document_to_index_task.py" = ["ARG"] +"tasks/test_batch_clean_document_task.py" = ["ARG"] +"tasks/test_batch_create_segment_to_index_task.py" = ["ARG"] +"tasks/test_clean_dataset_task.py" = ["T201"] +"tasks/test_clean_notion_document_task.py" = ["ARG"] +"tasks/test_create_segment_to_index_task.py" = ["ARG"] +"tasks/test_dataset_indexing_task.py" = ["ARG"] +"tasks/test_deal_dataset_vector_index_task.py" = ["ARG"] +"tasks/test_delete_segment_from_index_task.py" = ["ARG"] +"tasks/test_disable_segment_from_index_task.py" = ["ARG"] +"tasks/test_disable_segments_from_index_task.py" = ["ARG"] +"tasks/test_document_indexing_sync_task.py" = ["ARG"] +"tasks/test_document_indexing_task.py" = ["ARG"] +"tasks/test_document_indexing_update_task.py" = ["ARG"] +"tasks/test_duplicate_document_indexing_task.py" = ["ARG"] +"tasks/test_enable_segments_to_index_task.py" = ["ARG"] +"tasks/test_mail_change_mail_task.py" = ["ARG"] +"tasks/test_mail_email_code_login_task.py" = ["ARG"] +"tasks/test_mail_human_input_delivery_task.py" = ["ARG"] +"tasks/test_mail_inner_task.py" = ["ARG"] +"tasks/test_mail_invite_member_task.py" = ["ARG"] +"tasks/test_mail_owner_transfer_task.py" = ["ARG"] +"tasks/test_mail_register_task.py" = ["ARG"] +"tasks/test_rag_pipeline_run_tasks.py" = ["ARG"] +"test_workflow_pause_integration.py" = ["T201"] +"workflow/nodes/code_executor/test_code_javascript.py" = ["ARG"] +"workflow/nodes/code_executor/test_code_jinja2.py" = ["ARG"] +"workflow/nodes/code_executor/test_code_python3.py" = ["ARG"] +"workflow/nodes/code_executor/test_utils.py" = ["T201"] [lint.flake8-tidy-imports.banned-api."typing.Any"] msg = "Use object, Protocol, TypedDict, TypeVar, ParamSpec, or a localized cast instead." diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 2ee9ae68b2..8099756f66 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -424,6 +424,7 @@ def flask_app_with_containers(set_up_containers_and_env: DifyTestContainers) -> Returns: Flask: Configured Flask application """ + assert set_up_containers_and_env is _container_manager logger.info("=== Creating session-scoped Flask application ===") app = _create_app_with_containers() logger.info("Session-scoped Flask application created successfully") diff --git a/api/tests/test_containers_integration_tests/pyrefly.toml b/api/tests/test_containers_integration_tests/pyrefly.toml index e8100ee3c9..36d83da43e 100644 --- a/api/tests/test_containers_integration_tests/pyrefly.toml +++ b/api/tests/test_containers_integration_tests/pyrefly.toml @@ -83,7 +83,6 @@ project-excludes = [ "services/test_conversation_service.py", "services/test_conversation_service_variables.py", "services/test_conversation_variable_updater.py", - "services/test_credit_pool_service.py", "services/test_dataset_permission_service.py", "services/test_dataset_service.py", "services/test_dataset_service_batch_update_document_status.py", diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py index dfa3bc9f01..e7cf04c0d8 100644 --- a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py @@ -2,6 +2,7 @@ from __future__ import annotations from uuid import uuid4 +import pytest from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -45,7 +46,8 @@ class TestGetPermission: assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE - def test_returns_none_when_not_found(self, db_session_with_containers: Session) -> None: + @pytest.mark.usefixtures("flask_app_with_containers") + def test_returns_none_when_not_found(self) -> None: result = PluginPermissionService.get_permission(_tenant_id()) assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 07dc3a4e9e..de8e6ba612 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -1,10 +1,13 @@ """Testcontainers integration tests for CreditPoolService.""" +from unittest.mock import patch from uuid import uuid4 import pytest +from flask import has_app_context from sqlalchemy.orm import Session +from core.db.session_factory import session_factory from core.errors.error import QuotaExceededError from models import TenantCreditPool from models.enums import ProviderQuotaType @@ -15,7 +18,25 @@ class TestCreditPoolService: def _create_tenant_id(self) -> str: return str(uuid4()) - def test_create_default_pool(self, db_session_with_containers: Session): + def _create_pool( + self, + db_session: Session, + *, + tenant_id: str, + quota_limit: int = 10, + quota_used: int = 0, + ) -> None: + pool = TenantCreditPool( + tenant_id=tenant_id, + pool_type=ProviderQuotaType.TRIAL, + quota_limit=quota_limit, + quota_used=quota_used, + ) + db_session.add(pool) + db_session.commit() + + @pytest.mark.usefixtures("db_session_with_containers") + def test_create_default_pool(self) -> None: tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) @@ -26,9 +47,9 @@ class TestCreditPoolService: assert pool.quota_used == 0 assert pool.quota_limit > 0 - def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session): + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session) -> None: tenant_id = self._create_tenant_id() - CreditPoolService.create_default_pool(tenant_id) + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=0) result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) @@ -36,88 +57,191 @@ class TestCreditPoolService: assert result.tenant_id == tenant_id assert result.pool_type == ProviderQuotaType.TRIAL - def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session): + @pytest.mark.usefixtures("flask_app_with_containers") + def test_get_pool_uses_configured_session_factory_without_flask_app_context(self) -> None: + tenant_id = self._create_tenant_id() + session_maker = session_factory.get_session_maker() + with session_maker.begin() as session: + session.add( + TenantCreditPool( + tenant_id=tenant_id, + pool_type=ProviderQuotaType.TRIAL, + quota_limit=10, + quota_used=2, + ) + ) + + assert not has_app_context() + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) + + assert result is not None + assert result.tenant_id == tenant_id + assert result.pool_type == ProviderQuotaType.TRIAL + assert result.quota_used == 2 + + @pytest.mark.usefixtures("flask_app_with_containers") + def test_get_pool_returns_none_when_not_exists(self) -> None: result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None - def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session): + @pytest.mark.usefixtures("flask_app_with_containers") + def test_check_credits_available_returns_false_when_no_pool(self) -> None: result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) assert result is False - def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session): + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session) -> None: tenant_id = self._create_tenant_id() - CreditPoolService.create_default_pool(tenant_id) + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=0) result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10) assert result is True - def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session): + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session) -> None: tenant_id = self._create_tenant_id() - pool = CreditPoolService.create_default_pool(tenant_id) - # Exhaust credits - pool.quota_used = pool.quota_limit - db_session_with_containers.commit() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10) result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1) assert result is False - def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session): + @pytest.mark.usefixtures("flask_app_with_containers") + def test_check_and_deduct_credits_raises_when_no_pool(self) -> None: with pytest.raises(QuotaExceededError, match="Credit pool not found"): - CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) + CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=1) - def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session): + def test_check_and_deduct_credits_returns_zero_for_non_positive_request( + self, db_session_with_containers: Session + ) -> None: tenant_id = self._create_tenant_id() - pool = CreditPoolService.create_default_pool(tenant_id) - pool.quota_used = pool.quota_limit - db_session_with_containers.commit() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2) + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=0) + + assert result == 0 + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool is not None + assert updated_pool.quota_used == 2 + + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session) -> None: + tenant_id = self._create_tenant_id() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10) with pytest.raises(QuotaExceededError, match="No credits remaining"): - CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1) - def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session): + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool is not None + assert updated_pool.quota_used == 10 + + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session) -> None: tenant_id = self._create_tenant_id() - CreditPoolService.create_default_pool(tenant_id) - credits_required = 10 + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2) + credits_required = 3 result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required) assert result == credits_required - db_session_with_containers.expire_all() pool = CreditPoolService.get_pool(tenant_id=tenant_id) - assert pool.quota_used == credits_required + assert pool is not None + assert pool.quota_used == 5 def test_check_and_deduct_credits_raises_without_deducting_when_insufficient( self, db_session_with_containers: Session - ): + ) -> None: tenant_id = self._create_tenant_id() - pool = CreditPoolService.create_default_pool(tenant_id) - remaining = 5 - pool.quota_used = pool.quota_limit - remaining - quota_used = pool.quota_used - db_session_with_containers.commit() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=9) with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"): - CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3) - db_session_with_containers.expire_all() updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) - assert updated_pool.quota_used == quota_used + assert updated_pool is not None + assert updated_pool.quota_used == 9 - def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session): + def test_check_and_deduct_credits_wraps_unexpected_deduction_errors( + self, db_session_with_containers: Session + ) -> None: tenant_id = self._create_tenant_id() - pool = CreditPoolService.create_default_pool(tenant_id) - remaining = 5 - pool.quota_used = pool.quota_limit - remaining - quota_limit = pool.quota_limit - db_session_with_containers.commit() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2) - result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200) + with ( + patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")), + pytest.raises(QuotaExceededError, match="Failed to deduct credits"), + ): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1) - assert result == remaining - db_session_with_containers.expire_all() updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) - assert updated_pool.quota_used == quota_limit + assert updated_pool is not None + assert updated_pool.quota_used == 2 + + def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session) -> None: + tenant_id = self._create_tenant_id() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=9) + + result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3) + + assert result == 1 + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool is not None + assert updated_pool.quota_used == 10 + + def test_deduct_credits_capped_returns_zero_for_non_positive_request( + self, db_session_with_containers: Session + ) -> None: + tenant_id = self._create_tenant_id() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2) + + result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=0) + + assert result == 0 + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool is not None + assert updated_pool.quota_used == 2 + + @pytest.mark.usefixtures("flask_app_with_containers") + def test_deduct_credits_capped_returns_zero_when_no_pool(self) -> None: + result = CreditPoolService.deduct_credits_capped(tenant_id=self._create_tenant_id(), credits_required=1) + + assert result == 0 + + def test_deduct_credits_capped_returns_zero_when_pool_is_empty(self, db_session_with_containers: Session) -> None: + tenant_id = self._create_tenant_id() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10) + + result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) + + assert result == 0 + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool is not None + assert updated_pool.quota_used == 10 + + def test_deduct_credits_capped_wraps_unexpected_deduction_errors(self, db_session_with_containers: Session) -> None: + tenant_id = self._create_tenant_id() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2) + + with ( + patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")), + pytest.raises(QuotaExceededError, match="Failed to deduct credits"), + ): + CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) + + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool is not None + assert updated_pool.quota_used == 2 + + def test_deduct_credits_capped_reraises_quota_exceeded_errors(self, db_session_with_containers: Session) -> None: + tenant_id = self._create_tenant_id() + self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2) + + with ( + patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")), + pytest.raises(QuotaExceededError, match="quota unavailable"), + ): + CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) + + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool is not None + assert updated_pool.quota_used == 2 diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 0aea7151e9..3a6d635e63 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -54,14 +54,12 @@ class TestTriggerProviderService: def _create_test_account_and_tenant( self, - db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies, ) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. Args: - db_session_with_containers: Database session from testcontainers infrastructure mock_external_service_dependencies: Mock dependencies Returns: @@ -166,9 +164,7 @@ class TestTriggerProviderService: - Database state is correctly updated """ fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) + account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -266,9 +262,7 @@ class TestTriggerProviderService: - Merged credentials contain only new values """ fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) + account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -326,9 +320,7 @@ class TestTriggerProviderService: - Original credentials are preserved """ fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) + account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -384,9 +376,7 @@ class TestTriggerProviderService: - UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials """ fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) + account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -444,9 +434,7 @@ class TestTriggerProviderService: - Original subscription state is preserved """ fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) + account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -486,8 +474,9 @@ class TestTriggerProviderService: assert subscription.name == original_name assert subscription.parameters == original_parameters + @pytest.mark.usefixtures("db_session_with_containers") def test_rebuild_trigger_subscription_subscription_not_found( - self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies + self, mock_external_service_dependencies: MockExternalServiceDependencies ) -> None: """ Test error when subscription is not found. @@ -496,9 +485,7 @@ class TestTriggerProviderService: - Proper error is raised when subscription doesn't exist """ fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) + account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") fake_subscription_id = fake.uuid4() @@ -522,9 +509,7 @@ class TestTriggerProviderService: - Error is raised when new name conflicts with existing subscription """ fake = Faker() - account, tenant = self._create_test_account_and_tenant( - db_session_with_containers, mock_external_service_dependencies - ) + account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py deleted file mode 100644 index 6956dbbd6e..0000000000 --- a/api/tests/unit_tests/services/test_credit_pool_service.py +++ /dev/null @@ -1,178 +0,0 @@ -from collections.abc import Generator -from contextlib import contextmanager -from unittest.mock import patch -from uuid import uuid4 - -import pytest -from sqlalchemy import create_engine, select -from sqlalchemy.engine import Engine -from sqlalchemy.orm import sessionmaker - -from core.errors.error import QuotaExceededError -from models import TenantCreditPool -from models.enums import ProviderQuotaType -from services.credit_pool_service import CreditPoolService - - -def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]: - engine = create_engine("sqlite:///:memory:") - TenantCreditPool.__table__.create(engine) - tenant_id = str(uuid4()) - pool_id = str(uuid4()) - with engine.begin() as connection: - connection.execute( - TenantCreditPool.__table__.insert(), - { - "id": pool_id, - "tenant_id": tenant_id, - "pool_type": ProviderQuotaType.TRIAL, - "quota_limit": quota_limit, - "quota_used": quota_used, - }, - ) - return engine, tenant_id, pool_id - - -@contextmanager -def _patched_session_factory(engine: Engine) -> Generator[None, None, None]: - session_maker = sessionmaker(bind=engine, expire_on_commit=False) - with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker): - yield - - -def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None: - with engine.connect() as connection: - return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)) - - -def test_get_pool_uses_configured_session_factory_without_flask_app_context() -> None: - engine, tenant_id, _ = _create_engine_with_pool(quota_limit=10, quota_used=2) - - with _patched_session_factory(engine): - pool = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL) - - assert pool is not None - assert pool.tenant_id == tenant_id - assert pool.quota_used == 2 - - -def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None: - engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) - - with _patched_session_factory(engine): - deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3) - - assert deducted_credits == 3 - assert _get_quota_used(engine=engine, pool_id=pool_id) == 5 - - -def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None: - assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0 - - -def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None: - engine = create_engine("sqlite:///:memory:") - TenantCreditPool.__table__.create(engine) - - with ( - _patched_session_factory(engine), - pytest.raises(QuotaExceededError, match="Credit pool not found"), - ): - CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1) - - -def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None: - engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10) - - with ( - _patched_session_factory(engine), - pytest.raises(QuotaExceededError, match="No credits remaining"), - ): - CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1) - - assert _get_quota_used(engine=engine, pool_id=pool_id) == 10 - - -def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None: - engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9) - - with ( - _patched_session_factory(engine), - pytest.raises(QuotaExceededError, match="Insufficient credits remaining"), - ): - CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3) - - assert _get_quota_used(engine=engine, pool_id=pool_id) == 9 - - -def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None: - engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) - - with ( - _patched_session_factory(engine), - patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")), - pytest.raises(QuotaExceededError, match="Failed to deduct credits"), - ): - CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1) - - assert _get_quota_used(engine=engine, pool_id=pool_id) == 2 - - -def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None: - assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0 - - -def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None: - engine = create_engine("sqlite:///:memory:") - TenantCreditPool.__table__.create(engine) - - with _patched_session_factory(engine): - deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1) - - assert deducted_credits == 0 - - -def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None: - engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10) - - with _patched_session_factory(engine): - deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) - - assert deducted_credits == 0 - assert _get_quota_used(engine=engine, pool_id=pool_id) == 10 - - -def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None: - engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9) - - with _patched_session_factory(engine): - deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3) - - assert deducted_credits == 1 - assert _get_quota_used(engine=engine, pool_id=pool_id) == 10 - - -def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None: - engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) - - with ( - _patched_session_factory(engine), - patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")), - pytest.raises(QuotaExceededError, match="Failed to deduct credits"), - ): - CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) - - assert _get_quota_used(engine=engine, pool_id=pool_id) == 2 - - -def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None: - engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) - - with ( - _patched_session_factory(engine), - patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")), - pytest.raises(QuotaExceededError, match="quota unavailable"), - ): - CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) - - assert _get_quota_used(engine=engine, pool_id=pool_id) == 2