From c0a907b34fa8e766d0ff0852d34fa1ae547c18c6 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 8 May 2026 16:11:52 +0800 Subject: [PATCH] fix(api): cap workflow credit deductions --- api/core/app/llm/quota.py | 4 +- .../unit_tests/core/app/test_llm_quota.py | 65 +++++++++++++++++-- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 64a9229ec8..5bf3334a7b 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -122,14 +122,14 @@ def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configurat case ProviderQuotaType.TRIAL: from services.credit_pool_service import CreditPoolService - CreditPoolService.check_and_deduct_credits( + CreditPoolService.deduct_credits_capped( tenant_id=tenant_id, credits_required=used_quota, ) case ProviderQuotaType.PAID: from services.credit_pool_service import CreditPoolService - CreditPoolService.check_and_deduct_credits( + CreditPoolService.deduct_credits_capped( tenant_id=tenant_id, credits_required=used_quota, pool_type="paid", diff --git a/api/tests/unit_tests/core/app/test_llm_quota.py b/api/tests/unit_tests/core/app/test_llm_quota.py index 7275d6b512..d9390a4a8f 100644 --- a/api/tests/unit_tests/core/app/test_llm_quota.py +++ b/api/tests/unit_tests/core/app/test_llm_quota.py @@ -16,6 +16,8 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.model_entities import ModelType +from models import TenantCreditPool +from models.enums import ProviderQuotaType as ModelProviderQuotaType from models.provider import Provider, ProviderType @@ -97,7 +99,7 @@ def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None: with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), - patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, ): deduct_llm_quota_for_model( tenant_id="tenant-id", @@ -112,6 +114,55 @@ def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None: ) +def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 3 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + engine = create_engine("sqlite:///:memory:") + TenantCreditPool.__table__.create(engine) + with engine.begin() as connection: + connection.execute( + TenantCreditPool.__table__.insert(), + { + "id": "trial-pool", + "tenant_id": "tenant-id", + "pool_type": ModelProviderQuotaType.TRIAL, + "quota_limit": 10, + "quota_used": 9, + }, + ) + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + with engine.connect() as connection: + quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == "trial-pool")) + + assert quota_used == 10 + + def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None: usage = LLMUsage.empty_usage() usage.total_tokens = 42 @@ -133,7 +184,7 @@ def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None: with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), - patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, ): deduct_llm_quota_for_model( tenant_id="tenant-id", @@ -166,7 +217,7 @@ def test_deduct_llm_quota_for_model_uses_credit_configuration() -> None: with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), patch.object(type(dify_config), "get_model_credits", return_value=9) as mock_get_model_credits, - patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, ): deduct_llm_quota_for_model( tenant_id="tenant-id", @@ -202,7 +253,7 @@ def test_deduct_llm_quota_for_model_uses_single_charge_for_times_quota() -> None with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), - patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, ): deduct_llm_quota_for_model( tenant_id="tenant-id", @@ -238,7 +289,7 @@ def test_deduct_llm_quota_for_model_uses_paid_billing_pool() -> None: with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), - patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, ): deduct_llm_quota_for_model( tenant_id="tenant-id", @@ -439,7 +490,7 @@ def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None: with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), - patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker, ): deduct_llm_quota_for_model( @@ -468,7 +519,7 @@ def test_deduct_llm_quota_for_model_ignores_custom_provider_configuration() -> N with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), - patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker, ): deduct_llm_quota_for_model(