test: migrate credit pool service tests to Testcontainers (#37252)

This commit is contained in:
Escape0707 2026-06-10 10:55:50 +09:00 committed by GitHub
parent 3fb1d3055e
commit 212b819f1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 284 additions and 254 deletions

View File

@ -2,18 +2,115 @@ extend = "../../.ruff.toml"
src = ["../.."]
[lint]
extend-select = ["ANN401", "TID251"]
extend-select = ["ANN401", "ARG", "TID251"]
[lint.per-file-ignores]
"**/*.py" = ["S110", "T201"]
"core/rag/pipeline/test_queue_integration.py" = ["ANN401", "TID251"]
"core/rag/pipeline/test_queue_integration.py" = ["ANN401", "TID251", "ARG"]
"models/test_types_enum_text.py" = ["ANN401", "TID251"]
"services/test_app_dsl_service.py" = ["ANN401", "TID251"]
"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251"]
"services/test_app_dsl_service.py" = ["ANN401", "TID251", "ARG"]
"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251", "ARG"]
"services/test_hit_testing_service.py" = ["ANN401", "TID251"]
"services/test_recommended_app_service.py" = ["ANN401", "TID251"]
"services/test_recommended_app_service.py" = ["ANN401", "TID251", "ARG"]
"trigger/conftest.py" = ["ANN401", "TID251"]
"trigger/test_trigger_e2e.py" = ["ANN401", "TID251"]
"trigger/test_trigger_e2e.py" = ["ANN401", "TID251", "ARG"]
"controllers/console/app/test_app_apis.py" = ["ARG"]
"controllers/console/app/test_app_import_api.py" = ["ARG"]
"controllers/console/auth/test_oauth.py" = ["ARG"]
"controllers/console/auth/test_password_reset.py" = ["ARG"]
"controllers/console/datasets/test_data_source.py" = ["ARG"]
"controllers/console/test_apikey.py" = ["ARG"]
"controllers/console/workspace/test_tool_provider.py" = ["ARG"]
"controllers/mcp/test_mcp.py" = ["ARG"]
"controllers/service_api/dataset/test_dataset.py" = ["ARG"]
"controllers/web/test_conversation.py" = ["ARG"]
"controllers/web/test_human_input_form.py" = ["ARG"]
"controllers/web/test_wraps.py" = ["ARG"]
"core/app/layers/test_pause_state_persist_layer.py" = ["ARG"]
"core/rag/retrieval/test_dataset_retrieval_integration.py" = ["ARG"]
"models/test_account.py" = ["ARG"]
"models/test_conversation_message_inputs.py" = ["ARG"]
"models/test_conversation_status_count.py" = ["ARG"]
"repositories/test_sqlalchemy_api_workflow_run_repository.py" = ["ARG"]
"repositories/test_workflow_run_repository.py" = ["ARG"]
"services/auth/test_api_key_auth_service.py" = ["ARG"]
"services/auth/test_auth_integration.py" = ["ARG"]
"services/dataset_collection_binding.py" = ["ARG"]
"services/dataset_service_update_delete.py" = ["ARG"]
"services/document_service_status.py" = ["ARG"]
"services/enterprise/test_account_deletion_sync.py" = ["ARG"]
"services/plugin/test_plugin_parameter_service.py" = ["ARG"]
"services/plugin/test_plugin_service.py" = ["ARG"]
"services/rag_pipeline/test_rag_pipeline_service_db.py" = ["ARG"]
"services/recommend_app/test_database_retrieval.py" = ["ARG"]
"services/test_account_service.py" = ["ARG"]
"services/test_advanced_prompt_template_service.py" = ["ARG"]
"services/test_annotation_service.py" = ["ARG"]
"services/test_api_based_extension_service.py" = ["ARG"]
"services/test_api_token_service.py" = ["ARG"]
"services/test_app_generate_service.py" = ["ARG"]
"services/test_app_service.py" = ["ARG"]
"services/test_attachment_service.py" = ["ARG"]
"services/test_conversation_variable_updater.py" = ["ARG"]
"services/test_dataset_permission_service.py" = ["ARG"]
"services/test_dataset_service_batch_update_document_status.py" = ["ARG"]
"services/test_dataset_service_retrieval.py" = ["ARG"]
"services/test_delete_archived_workflow_run.py" = ["ARG"]
"services/test_document_service_rename_document.py" = ["ARG"]
"services/test_end_user_service.py" = ["ARG"]
"services/test_feature_service.py" = ["ARG"]
"services/test_feedback_service.py" = ["ARG"]
"services/test_file_service.py" = ["ARG"]
"services/test_human_input_delivery_test_service.py" = ["ARG"]
"services/test_message_service.py" = ["ARG"]
"services/test_messages_clean_service.py" = ["ARG", "S110"]
"services/test_metadata_partial_update.py" = ["ARG"]
"services/test_metadata_service.py" = ["ARG"]
"services/test_model_load_balancing_service.py" = ["ARG"]
"services/test_model_provider_service.py" = ["ARG"]
"services/test_oauth_server_service.py" = ["ARG"]
"services/test_ops_service.py" = ["ARG"]
"services/test_saved_message_service.py" = ["ARG"]
"services/test_web_conversation_service.py" = ["ARG"]
"services/test_webapp_auth_service.py" = ["ARG"]
"services/test_webhook_service.py" = ["ARG"]
"services/test_workflow_app_service.py" = ["ARG"]
"services/test_workflow_draft_variable_service.py" = ["ARG"]
"services/test_workflow_run_service.py" = ["ARG"]
"services/test_workflow_service.py" = ["ARG"]
"services/test_workspace_service.py" = ["ARG"]
"services/tools/test_api_tools_manage_service.py" = ["ARG"]
"services/tools/test_mcp_tools_manage_service.py" = ["ARG"]
"services/tools/test_tools_transform_service.py" = ["ARG"]
"services/workflow/test_workflow_converter.py" = ["ARG"]
"tasks/test_add_document_to_index_task.py" = ["ARG"]
"tasks/test_batch_clean_document_task.py" = ["ARG"]
"tasks/test_batch_create_segment_to_index_task.py" = ["ARG"]
"tasks/test_clean_dataset_task.py" = ["T201"]
"tasks/test_clean_notion_document_task.py" = ["ARG"]
"tasks/test_create_segment_to_index_task.py" = ["ARG"]
"tasks/test_dataset_indexing_task.py" = ["ARG"]
"tasks/test_deal_dataset_vector_index_task.py" = ["ARG"]
"tasks/test_delete_segment_from_index_task.py" = ["ARG"]
"tasks/test_disable_segment_from_index_task.py" = ["ARG"]
"tasks/test_disable_segments_from_index_task.py" = ["ARG"]
"tasks/test_document_indexing_sync_task.py" = ["ARG"]
"tasks/test_document_indexing_task.py" = ["ARG"]
"tasks/test_document_indexing_update_task.py" = ["ARG"]
"tasks/test_duplicate_document_indexing_task.py" = ["ARG"]
"tasks/test_enable_segments_to_index_task.py" = ["ARG"]
"tasks/test_mail_change_mail_task.py" = ["ARG"]
"tasks/test_mail_email_code_login_task.py" = ["ARG"]
"tasks/test_mail_human_input_delivery_task.py" = ["ARG"]
"tasks/test_mail_inner_task.py" = ["ARG"]
"tasks/test_mail_invite_member_task.py" = ["ARG"]
"tasks/test_mail_owner_transfer_task.py" = ["ARG"]
"tasks/test_mail_register_task.py" = ["ARG"]
"tasks/test_rag_pipeline_run_tasks.py" = ["ARG"]
"test_workflow_pause_integration.py" = ["T201"]
"workflow/nodes/code_executor/test_code_javascript.py" = ["ARG"]
"workflow/nodes/code_executor/test_code_jinja2.py" = ["ARG"]
"workflow/nodes/code_executor/test_code_python3.py" = ["ARG"]
"workflow/nodes/code_executor/test_utils.py" = ["T201"]
[lint.flake8-tidy-imports.banned-api."typing.Any"]
msg = "Use object, Protocol, TypedDict, TypeVar, ParamSpec, or a localized cast instead."

View File

@ -424,6 +424,7 @@ def flask_app_with_containers(set_up_containers_and_env: DifyTestContainers) ->
Returns:
Flask: Configured Flask application
"""
assert set_up_containers_and_env is _container_manager
logger.info("=== Creating session-scoped Flask application ===")
app = _create_app_with_containers()
logger.info("Session-scoped Flask application created successfully")

View File

@ -83,7 +83,6 @@ project-excludes = [
"services/test_conversation_service.py",
"services/test_conversation_service_variables.py",
"services/test_conversation_variable_updater.py",
"services/test_credit_pool_service.py",
"services/test_dataset_permission_service.py",
"services/test_dataset_service.py",
"services/test_dataset_service_batch_update_document_status.py",

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from uuid import uuid4
import pytest
from sqlalchemy import func, select
from sqlalchemy.orm import Session
@ -45,7 +46,8 @@ class TestGetPermission:
assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
def test_returns_none_when_not_found(self, db_session_with_containers: Session) -> None:
@pytest.mark.usefixtures("flask_app_with_containers")
def test_returns_none_when_not_found(self) -> None:
result = PluginPermissionService.get_permission(_tenant_id())
assert result is None

View File

@ -1,10 +1,13 @@
"""Testcontainers integration tests for CreditPoolService."""
from unittest.mock import patch
from uuid import uuid4
import pytest
from flask import has_app_context
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from models.enums import ProviderQuotaType
@ -15,7 +18,25 @@ class TestCreditPoolService:
def _create_tenant_id(self) -> str:
return str(uuid4())
def test_create_default_pool(self, db_session_with_containers: Session):
def _create_pool(
self,
db_session: Session,
*,
tenant_id: str,
quota_limit: int = 10,
quota_used: int = 0,
) -> None:
pool = TenantCreditPool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL,
quota_limit=quota_limit,
quota_used=quota_used,
)
db_session.add(pool)
db_session.commit()
@pytest.mark.usefixtures("db_session_with_containers")
def test_create_default_pool(self) -> None:
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
@ -26,9 +47,9 @@ class TestCreditPoolService:
assert pool.quota_used == 0
assert pool.quota_limit > 0
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session):
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=0)
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
@ -36,88 +57,191 @@ class TestCreditPoolService:
assert result.tenant_id == tenant_id
assert result.pool_type == ProviderQuotaType.TRIAL
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session):
@pytest.mark.usefixtures("flask_app_with_containers")
def test_get_pool_uses_configured_session_factory_without_flask_app_context(self) -> None:
tenant_id = self._create_tenant_id()
session_maker = session_factory.get_session_maker()
with session_maker.begin() as session:
session.add(
TenantCreditPool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL,
quota_limit=10,
quota_used=2,
)
)
assert not has_app_context()
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
assert result is not None
assert result.tenant_id == tenant_id
assert result.pool_type == ProviderQuotaType.TRIAL
assert result.quota_used == 2
@pytest.mark.usefixtures("flask_app_with_containers")
def test_get_pool_returns_none_when_not_exists(self) -> None:
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
assert result is None
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session):
@pytest.mark.usefixtures("flask_app_with_containers")
def test_check_credits_available_returns_false_when_no_pool(self) -> None:
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
assert result is False
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session):
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=0)
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10)
assert result is True
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session):
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
# Exhaust credits
pool.quota_used = pool.quota_limit
db_session_with_containers.commit()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10)
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1)
assert result is False
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session):
@pytest.mark.usefixtures("flask_app_with_containers")
def test_check_and_deduct_credits_raises_when_no_pool(self) -> None:
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=1)
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session):
def test_check_and_deduct_credits_returns_zero_for_non_positive_request(
self, db_session_with_containers: Session
) -> None:
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
pool.quota_used = pool.quota_limit
db_session_with_containers.commit()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=0)
assert result == 0
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool is not None
assert updated_pool.quota_used == 2
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10)
with pytest.raises(QuotaExceededError, match="No credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session):
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool is not None
assert updated_pool.quota_used == 10
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
CreditPoolService.create_default_pool(tenant_id)
credits_required = 10
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
credits_required = 3
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required)
assert result == credits_required
db_session_with_containers.expire_all()
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
assert pool is not None
assert pool.quota_used == 5
def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
self, db_session_with_containers: Session
):
) -> None:
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
quota_used = pool.quota_used
db_session_with_containers.commit()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=9)
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == quota_used
assert updated_pool is not None
assert updated_pool.quota_used == 9
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors(
self, db_session_with_containers: Session
) -> None:
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
quota_limit = pool.quota_limit
db_session_with_containers.commit()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
with (
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert result == remaining
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == quota_limit
assert updated_pool is not None
assert updated_pool.quota_used == 2
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=9)
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
assert result == 1
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool is not None
assert updated_pool.quota_used == 10
def test_deduct_credits_capped_returns_zero_for_non_positive_request(
self, db_session_with_containers: Session
) -> None:
tenant_id = self._create_tenant_id()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=0)
assert result == 0
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool is not None
assert updated_pool.quota_used == 2
@pytest.mark.usefixtures("flask_app_with_containers")
def test_deduct_credits_capped_returns_zero_when_no_pool(self) -> None:
result = CreditPoolService.deduct_credits_capped(tenant_id=self._create_tenant_id(), credits_required=1)
assert result == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_empty(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10)
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert result == 0
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool is not None
assert updated_pool.quota_used == 10
def test_deduct_credits_capped_wraps_unexpected_deduction_errors(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
with (
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool is not None
assert updated_pool.quota_used == 2
def test_deduct_credits_capped_reraises_quota_exceeded_errors(self, db_session_with_containers: Session) -> None:
tenant_id = self._create_tenant_id()
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
with (
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
pytest.raises(QuotaExceededError, match="quota unavailable"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool is not None
assert updated_pool.quota_used == 2

View File

@ -54,14 +54,12 @@ class TestTriggerProviderService:
def _create_test_account_and_tenant(
self,
db_session_with_containers: Session,
mock_external_service_dependencies: MockExternalServiceDependencies,
) -> tuple[Account, Tenant]:
"""
Helper method to create a test account and tenant for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
Returns:
@ -166,9 +164,7 @@ class TestTriggerProviderService:
- Database state is correctly updated
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -266,9 +262,7 @@ class TestTriggerProviderService:
- Merged credentials contain only new values
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -326,9 +320,7 @@ class TestTriggerProviderService:
- Original credentials are preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -384,9 +376,7 @@ class TestTriggerProviderService:
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -444,9 +434,7 @@ class TestTriggerProviderService:
- Original subscription state is preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -486,8 +474,9 @@ class TestTriggerProviderService:
assert subscription.name == original_name
assert subscription.parameters == original_parameters
@pytest.mark.usefixtures("db_session_with_containers")
def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies
self, mock_external_service_dependencies: MockExternalServiceDependencies
) -> None:
"""
Test error when subscription is not found.
@ -496,9 +485,7 @@ class TestTriggerProviderService:
- Proper error is raised when subscription doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
fake_subscription_id = fake.uuid4()
@ -522,9 +509,7 @@ class TestTriggerProviderService:
- Error is raised when new name conflicts with existing subscription
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY

View File

@ -1,178 +0,0 @@
from collections.abc import Generator
from contextlib import contextmanager
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from models.enums import ProviderQuotaType
from services.credit_pool_service import CreditPoolService
def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
tenant_id = str(uuid4())
pool_id = str(uuid4())
with engine.begin() as connection:
connection.execute(
TenantCreditPool.__table__.insert(),
{
"id": pool_id,
"tenant_id": tenant_id,
"pool_type": ProviderQuotaType.TRIAL,
"quota_limit": quota_limit,
"quota_used": quota_used,
},
)
return engine, tenant_id, pool_id
@contextmanager
def _patched_session_factory(engine: Engine) -> Generator[None, None, None]:
session_maker = sessionmaker(bind=engine, expire_on_commit=False)
with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker):
yield
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
with engine.connect() as connection:
return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
def test_get_pool_uses_configured_session_factory_without_flask_app_context() -> None:
engine, tenant_id, _ = _create_engine_with_pool(quota_limit=10, quota_used=2)
with _patched_session_factory(engine):
pool = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
assert pool is not None
assert pool.tenant_id == tenant_id
assert pool.quota_used == 2
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with _patched_session_factory(engine):
deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
assert deducted_credits == 3
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None:
assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0
def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with (
_patched_session_factory(engine),
pytest.raises(QuotaExceededError, match="Credit pool not found"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
with (
_patched_session_factory(engine),
pytest.raises(QuotaExceededError, match="No credits remaining"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
with (
_patched_session_factory(engine),
pytest.raises(QuotaExceededError, match="Insufficient credits remaining"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
_patched_session_factory(engine),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None:
assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
with _patched_session_factory(engine):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)
assert deducted_credits == 0
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
with _patched_session_factory(engine):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert deducted_credits == 0
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
with _patched_session_factory(engine):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
assert deducted_credits == 1
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
_patched_session_factory(engine),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with (
_patched_session_factory(engine),
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
pytest.raises(QuotaExceededError, match="quota unavailable"),
):
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2