mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 18:24:09 +08:00
test: migrate credit pool service tests to Testcontainers (#37252)
This commit is contained in:
parent
3fb1d3055e
commit
212b819f1c
@ -2,18 +2,115 @@ extend = "../../.ruff.toml"
|
||||
src = ["../.."]
|
||||
|
||||
[lint]
|
||||
extend-select = ["ANN401", "TID251"]
|
||||
extend-select = ["ANN401", "ARG", "TID251"]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"**/*.py" = ["S110", "T201"]
|
||||
"core/rag/pipeline/test_queue_integration.py" = ["ANN401", "TID251"]
|
||||
"core/rag/pipeline/test_queue_integration.py" = ["ANN401", "TID251", "ARG"]
|
||||
"models/test_types_enum_text.py" = ["ANN401", "TID251"]
|
||||
"services/test_app_dsl_service.py" = ["ANN401", "TID251"]
|
||||
"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251"]
|
||||
"services/test_app_dsl_service.py" = ["ANN401", "TID251", "ARG"]
|
||||
"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251", "ARG"]
|
||||
"services/test_hit_testing_service.py" = ["ANN401", "TID251"]
|
||||
"services/test_recommended_app_service.py" = ["ANN401", "TID251"]
|
||||
"services/test_recommended_app_service.py" = ["ANN401", "TID251", "ARG"]
|
||||
"trigger/conftest.py" = ["ANN401", "TID251"]
|
||||
"trigger/test_trigger_e2e.py" = ["ANN401", "TID251"]
|
||||
"trigger/test_trigger_e2e.py" = ["ANN401", "TID251", "ARG"]
|
||||
"controllers/console/app/test_app_apis.py" = ["ARG"]
|
||||
"controllers/console/app/test_app_import_api.py" = ["ARG"]
|
||||
"controllers/console/auth/test_oauth.py" = ["ARG"]
|
||||
"controllers/console/auth/test_password_reset.py" = ["ARG"]
|
||||
"controllers/console/datasets/test_data_source.py" = ["ARG"]
|
||||
"controllers/console/test_apikey.py" = ["ARG"]
|
||||
"controllers/console/workspace/test_tool_provider.py" = ["ARG"]
|
||||
"controllers/mcp/test_mcp.py" = ["ARG"]
|
||||
"controllers/service_api/dataset/test_dataset.py" = ["ARG"]
|
||||
"controllers/web/test_conversation.py" = ["ARG"]
|
||||
"controllers/web/test_human_input_form.py" = ["ARG"]
|
||||
"controllers/web/test_wraps.py" = ["ARG"]
|
||||
"core/app/layers/test_pause_state_persist_layer.py" = ["ARG"]
|
||||
"core/rag/retrieval/test_dataset_retrieval_integration.py" = ["ARG"]
|
||||
"models/test_account.py" = ["ARG"]
|
||||
"models/test_conversation_message_inputs.py" = ["ARG"]
|
||||
"models/test_conversation_status_count.py" = ["ARG"]
|
||||
"repositories/test_sqlalchemy_api_workflow_run_repository.py" = ["ARG"]
|
||||
"repositories/test_workflow_run_repository.py" = ["ARG"]
|
||||
"services/auth/test_api_key_auth_service.py" = ["ARG"]
|
||||
"services/auth/test_auth_integration.py" = ["ARG"]
|
||||
"services/dataset_collection_binding.py" = ["ARG"]
|
||||
"services/dataset_service_update_delete.py" = ["ARG"]
|
||||
"services/document_service_status.py" = ["ARG"]
|
||||
"services/enterprise/test_account_deletion_sync.py" = ["ARG"]
|
||||
"services/plugin/test_plugin_parameter_service.py" = ["ARG"]
|
||||
"services/plugin/test_plugin_service.py" = ["ARG"]
|
||||
"services/rag_pipeline/test_rag_pipeline_service_db.py" = ["ARG"]
|
||||
"services/recommend_app/test_database_retrieval.py" = ["ARG"]
|
||||
"services/test_account_service.py" = ["ARG"]
|
||||
"services/test_advanced_prompt_template_service.py" = ["ARG"]
|
||||
"services/test_annotation_service.py" = ["ARG"]
|
||||
"services/test_api_based_extension_service.py" = ["ARG"]
|
||||
"services/test_api_token_service.py" = ["ARG"]
|
||||
"services/test_app_generate_service.py" = ["ARG"]
|
||||
"services/test_app_service.py" = ["ARG"]
|
||||
"services/test_attachment_service.py" = ["ARG"]
|
||||
"services/test_conversation_variable_updater.py" = ["ARG"]
|
||||
"services/test_dataset_permission_service.py" = ["ARG"]
|
||||
"services/test_dataset_service_batch_update_document_status.py" = ["ARG"]
|
||||
"services/test_dataset_service_retrieval.py" = ["ARG"]
|
||||
"services/test_delete_archived_workflow_run.py" = ["ARG"]
|
||||
"services/test_document_service_rename_document.py" = ["ARG"]
|
||||
"services/test_end_user_service.py" = ["ARG"]
|
||||
"services/test_feature_service.py" = ["ARG"]
|
||||
"services/test_feedback_service.py" = ["ARG"]
|
||||
"services/test_file_service.py" = ["ARG"]
|
||||
"services/test_human_input_delivery_test_service.py" = ["ARG"]
|
||||
"services/test_message_service.py" = ["ARG"]
|
||||
"services/test_messages_clean_service.py" = ["ARG", "S110"]
|
||||
"services/test_metadata_partial_update.py" = ["ARG"]
|
||||
"services/test_metadata_service.py" = ["ARG"]
|
||||
"services/test_model_load_balancing_service.py" = ["ARG"]
|
||||
"services/test_model_provider_service.py" = ["ARG"]
|
||||
"services/test_oauth_server_service.py" = ["ARG"]
|
||||
"services/test_ops_service.py" = ["ARG"]
|
||||
"services/test_saved_message_service.py" = ["ARG"]
|
||||
"services/test_web_conversation_service.py" = ["ARG"]
|
||||
"services/test_webapp_auth_service.py" = ["ARG"]
|
||||
"services/test_webhook_service.py" = ["ARG"]
|
||||
"services/test_workflow_app_service.py" = ["ARG"]
|
||||
"services/test_workflow_draft_variable_service.py" = ["ARG"]
|
||||
"services/test_workflow_run_service.py" = ["ARG"]
|
||||
"services/test_workflow_service.py" = ["ARG"]
|
||||
"services/test_workspace_service.py" = ["ARG"]
|
||||
"services/tools/test_api_tools_manage_service.py" = ["ARG"]
|
||||
"services/tools/test_mcp_tools_manage_service.py" = ["ARG"]
|
||||
"services/tools/test_tools_transform_service.py" = ["ARG"]
|
||||
"services/workflow/test_workflow_converter.py" = ["ARG"]
|
||||
"tasks/test_add_document_to_index_task.py" = ["ARG"]
|
||||
"tasks/test_batch_clean_document_task.py" = ["ARG"]
|
||||
"tasks/test_batch_create_segment_to_index_task.py" = ["ARG"]
|
||||
"tasks/test_clean_dataset_task.py" = ["T201"]
|
||||
"tasks/test_clean_notion_document_task.py" = ["ARG"]
|
||||
"tasks/test_create_segment_to_index_task.py" = ["ARG"]
|
||||
"tasks/test_dataset_indexing_task.py" = ["ARG"]
|
||||
"tasks/test_deal_dataset_vector_index_task.py" = ["ARG"]
|
||||
"tasks/test_delete_segment_from_index_task.py" = ["ARG"]
|
||||
"tasks/test_disable_segment_from_index_task.py" = ["ARG"]
|
||||
"tasks/test_disable_segments_from_index_task.py" = ["ARG"]
|
||||
"tasks/test_document_indexing_sync_task.py" = ["ARG"]
|
||||
"tasks/test_document_indexing_task.py" = ["ARG"]
|
||||
"tasks/test_document_indexing_update_task.py" = ["ARG"]
|
||||
"tasks/test_duplicate_document_indexing_task.py" = ["ARG"]
|
||||
"tasks/test_enable_segments_to_index_task.py" = ["ARG"]
|
||||
"tasks/test_mail_change_mail_task.py" = ["ARG"]
|
||||
"tasks/test_mail_email_code_login_task.py" = ["ARG"]
|
||||
"tasks/test_mail_human_input_delivery_task.py" = ["ARG"]
|
||||
"tasks/test_mail_inner_task.py" = ["ARG"]
|
||||
"tasks/test_mail_invite_member_task.py" = ["ARG"]
|
||||
"tasks/test_mail_owner_transfer_task.py" = ["ARG"]
|
||||
"tasks/test_mail_register_task.py" = ["ARG"]
|
||||
"tasks/test_rag_pipeline_run_tasks.py" = ["ARG"]
|
||||
"test_workflow_pause_integration.py" = ["T201"]
|
||||
"workflow/nodes/code_executor/test_code_javascript.py" = ["ARG"]
|
||||
"workflow/nodes/code_executor/test_code_jinja2.py" = ["ARG"]
|
||||
"workflow/nodes/code_executor/test_code_python3.py" = ["ARG"]
|
||||
"workflow/nodes/code_executor/test_utils.py" = ["T201"]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."typing.Any"]
|
||||
msg = "Use object, Protocol, TypedDict, TypeVar, ParamSpec, or a localized cast instead."
|
||||
|
||||
@ -424,6 +424,7 @@ def flask_app_with_containers(set_up_containers_and_env: DifyTestContainers) ->
|
||||
Returns:
|
||||
Flask: Configured Flask application
|
||||
"""
|
||||
assert set_up_containers_and_env is _container_manager
|
||||
logger.info("=== Creating session-scoped Flask application ===")
|
||||
app = _create_app_with_containers()
|
||||
logger.info("Session-scoped Flask application created successfully")
|
||||
|
||||
@ -83,7 +83,6 @@ project-excludes = [
|
||||
"services/test_conversation_service.py",
|
||||
"services/test_conversation_service_variables.py",
|
||||
"services/test_conversation_variable_updater.py",
|
||||
"services/test_credit_pool_service.py",
|
||||
"services/test_dataset_permission_service.py",
|
||||
"services/test_dataset_service.py",
|
||||
"services/test_dataset_service_batch_update_document_status.py",
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -45,7 +46,8 @@ class TestGetPermission:
|
||||
assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
||||
def test_returns_none_when_not_found(self, db_session_with_containers: Session) -> None:
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||
def test_returns_none_when_not_found(self) -> None:
|
||||
result = PluginPermissionService.get_permission(_tenant_id())
|
||||
|
||||
assert result is None
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
"""Testcontainers integration tests for CreditPoolService."""
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import has_app_context
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
@ -15,7 +18,25 @@ class TestCreditPoolService:
|
||||
def _create_tenant_id(self) -> str:
|
||||
return str(uuid4())
|
||||
|
||||
def test_create_default_pool(self, db_session_with_containers: Session):
|
||||
def _create_pool(
|
||||
self,
|
||||
db_session: Session,
|
||||
*,
|
||||
tenant_id: str,
|
||||
quota_limit: int = 10,
|
||||
quota_used: int = 0,
|
||||
) -> None:
|
||||
pool = TenantCreditPool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
quota_limit=quota_limit,
|
||||
quota_used=quota_used,
|
||||
)
|
||||
db_session.add(pool)
|
||||
db_session.commit()
|
||||
|
||||
@pytest.mark.usefixtures("db_session_with_containers")
|
||||
def test_create_default_pool(self) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
@ -26,9 +47,9 @@ class TestCreditPoolService:
|
||||
assert pool.quota_used == 0
|
||||
assert pool.quota_limit > 0
|
||||
|
||||
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session):
|
||||
def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=0)
|
||||
|
||||
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
@ -36,88 +57,191 @@ class TestCreditPoolService:
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.pool_type == ProviderQuotaType.TRIAL
|
||||
|
||||
def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session):
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||
def test_get_pool_uses_configured_session_factory_without_flask_app_context(self) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
session_maker = session_factory.get_session_maker()
|
||||
with session_maker.begin() as session:
|
||||
session.add(
|
||||
TenantCreditPool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=ProviderQuotaType.TRIAL,
|
||||
quota_limit=10,
|
||||
quota_used=2,
|
||||
)
|
||||
)
|
||||
|
||||
assert not has_app_context()
|
||||
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert result is not None
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.pool_type == ProviderQuotaType.TRIAL
|
||||
assert result.quota_used == 2
|
||||
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||
def test_get_pool_returns_none_when_not_exists(self) -> None:
|
||||
result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session):
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||
def test_check_credits_available_returns_false_when_no_pool(self) -> None:
|
||||
result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session):
|
||||
def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=0)
|
||||
|
||||
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session):
|
||||
def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
# Exhaust credits
|
||||
pool.quota_used = pool.quota_limit
|
||||
db_session_with_containers.commit()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10)
|
||||
|
||||
result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session):
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||
def test_check_and_deduct_credits_raises_when_no_pool(self) -> None:
|
||||
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10)
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=1)
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session):
|
||||
def test_check_and_deduct_credits_returns_zero_for_non_positive_request(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
pool.quota_used = pool.quota_limit
|
||||
db_session_with_containers.commit()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
|
||||
|
||||
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=0)
|
||||
|
||||
assert result == 0
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 2
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10)
|
||||
|
||||
with pytest.raises(QuotaExceededError, match="No credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10)
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session):
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 10
|
||||
|
||||
def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
CreditPoolService.create_default_pool(tenant_id)
|
||||
credits_required = 10
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
|
||||
credits_required = 3
|
||||
|
||||
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required)
|
||||
|
||||
assert result == credits_required
|
||||
db_session_with_containers.expire_all()
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert pool.quota_used == credits_required
|
||||
assert pool is not None
|
||||
assert pool.quota_used == 5
|
||||
|
||||
def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
quota_used = pool.quota_used
|
||||
db_session_with_containers.commit()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=9)
|
||||
|
||||
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == quota_used
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 9
|
||||
|
||||
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
|
||||
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
quota_limit = pool.quota_limit
|
||||
db_session_with_containers.commit()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
|
||||
|
||||
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
|
||||
with (
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert result == remaining
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == quota_limit
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 2
|
||||
|
||||
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=9)
|
||||
|
||||
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert result == 1
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 10
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_for_non_positive_request(
|
||||
self, db_session_with_containers: Session
|
||||
) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
|
||||
|
||||
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=0)
|
||||
|
||||
assert result == 0
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 2
|
||||
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
|
||||
def test_deduct_credits_capped_returns_zero_when_no_pool(self) -> None:
|
||||
result = CreditPoolService.deduct_credits_capped(tenant_id=self._create_tenant_id(), credits_required=1)
|
||||
|
||||
assert result == 0
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_when_pool_is_empty(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=10)
|
||||
|
||||
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert result == 0
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 10
|
||||
|
||||
def test_deduct_credits_capped_wraps_unexpected_deduction_errors(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 2
|
||||
|
||||
def test_deduct_credits_capped_reraises_quota_exceeded_errors(self, db_session_with_containers: Session) -> None:
|
||||
tenant_id = self._create_tenant_id()
|
||||
self._create_pool(db_session_with_containers, tenant_id=tenant_id, quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="quota unavailable"),
|
||||
):
|
||||
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool is not None
|
||||
assert updated_pool.quota_used == 2
|
||||
|
||||
@ -54,14 +54,12 @@ class TestTriggerProviderService:
|
||||
|
||||
def _create_test_account_and_tenant(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies: MockExternalServiceDependencies,
|
||||
) -> tuple[Account, Tenant]:
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
|
||||
Returns:
|
||||
@ -166,9 +164,7 @@ class TestTriggerProviderService:
|
||||
- Database state is correctly updated
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -266,9 +262,7 @@ class TestTriggerProviderService:
|
||||
- Merged credentials contain only new values
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -326,9 +320,7 @@ class TestTriggerProviderService:
|
||||
- Original credentials are preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -384,9 +376,7 @@ class TestTriggerProviderService:
|
||||
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -444,9 +434,7 @@ class TestTriggerProviderService:
|
||||
- Original subscription state is preserved
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
@ -486,8 +474,9 @@ class TestTriggerProviderService:
|
||||
assert subscription.name == original_name
|
||||
assert subscription.parameters == original_parameters
|
||||
|
||||
@pytest.mark.usefixtures("db_session_with_containers")
|
||||
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies
|
||||
self, mock_external_service_dependencies: MockExternalServiceDependencies
|
||||
) -> None:
|
||||
"""
|
||||
Test error when subscription is not found.
|
||||
@ -496,9 +485,7 @@ class TestTriggerProviderService:
|
||||
- Proper error is raised when subscription doesn't exist
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
fake_subscription_id = fake.uuid4()
|
||||
@ -522,9 +509,7 @@ class TestTriggerProviderService:
|
||||
- Error is raised when new name conflicts with existing subscription
|
||||
"""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
|
||||
|
||||
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
|
||||
credential_type = CredentialType.API_KEY
|
||||
|
||||
@ -1,178 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
tenant_id = str(uuid4())
|
||||
pool_id = str(uuid4())
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
TenantCreditPool.__table__.insert(),
|
||||
{
|
||||
"id": pool_id,
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": ProviderQuotaType.TRIAL,
|
||||
"quota_limit": quota_limit,
|
||||
"quota_used": quota_used,
|
||||
},
|
||||
)
|
||||
return engine, tenant_id, pool_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patched_session_factory(engine: Engine) -> Generator[None, None, None]:
|
||||
session_maker = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
with patch("services.credit_pool_service.session_factory.get_session_maker", return_value=session_maker):
|
||||
yield
|
||||
|
||||
|
||||
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
|
||||
with engine.connect() as connection:
|
||||
return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
|
||||
|
||||
|
||||
def test_get_pool_uses_configured_session_factory_without_flask_app_context() -> None:
|
||||
engine, tenant_id, _ = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with _patched_session_factory(engine):
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=ProviderQuotaType.TRIAL)
|
||||
|
||||
assert pool is not None
|
||||
assert pool.tenant_id == tenant_id
|
||||
assert pool.quota_used == 2
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with _patched_session_factory(engine):
|
||||
deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert deducted_credits == 3
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None:
|
||||
assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
|
||||
with (
|
||||
_patched_session_factory(engine),
|
||||
pytest.raises(QuotaExceededError, match="Credit pool not found"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1)
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
|
||||
|
||||
with (
|
||||
_patched_session_factory(engine),
|
||||
pytest.raises(QuotaExceededError, match="No credits remaining"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
with (
|
||||
_patched_session_factory(engine),
|
||||
pytest.raises(QuotaExceededError, match="Insufficient credits remaining"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
_patched_session_factory(engine),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None:
|
||||
assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
|
||||
with _patched_session_factory(engine):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1)
|
||||
|
||||
assert deducted_credits == 0
|
||||
|
||||
|
||||
def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10)
|
||||
|
||||
with _patched_session_factory(engine):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert deducted_credits == 0
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
with _patched_session_factory(engine):
|
||||
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
assert deducted_credits == 1
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10
|
||||
|
||||
|
||||
def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
_patched_session_factory(engine),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="Failed to deduct credits"),
|
||||
):
|
||||
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
|
||||
|
||||
def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
|
||||
|
||||
with (
|
||||
_patched_session_factory(engine),
|
||||
patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")),
|
||||
pytest.raises(QuotaExceededError, match="quota unavailable"),
|
||||
):
|
||||
CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1)
|
||||
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 2
|
||||
Loading…
Reference in New Issue
Block a user