fix(api): cap workflow credit deductions

This commit is contained in:
-LAN- 2026-05-08 16:11:52 +08:00
parent a8bb64bcf8
commit c0a907b34f
2 changed files with 60 additions and 9 deletions

View File

@ -122,14 +122,14 @@ def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configurat
case ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
CreditPoolService.deduct_credits_capped(
tenant_id=tenant_id,
credits_required=used_quota,
)
case ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
CreditPoolService.deduct_credits_capped(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",

View File

@ -16,6 +16,8 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from models import TenantCreditPool
from models.enums import ProviderQuotaType as ModelProviderQuotaType
from models.provider import Provider, ProviderType
@ -97,7 +99,7 @@ def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
@ -112,6 +114,55 @@ def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
)
def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining() -> None:
    """Check that trial usage larger than the remaining pool stops at the limit.

    The seeded trial pool has ``quota_limit=10`` and ``quota_used=9``. The call
    reports a usage of 3 total tokens, and the stored ``quota_used`` is then
    expected to be exactly 10.
    """
    llm_usage = LLMUsage.empty_usage()
    llm_usage.total_tokens = 3

    # Stand-in provider configuration: a system provider on a token-based trial quota.
    trial_quota = SimpleNamespace(
        quota_type=ProviderQuotaType.TRIAL,
        quota_unit=QuotaUnit.TOKENS,
        quota_limit=100,
    )
    system_configuration = SimpleNamespace(
        current_quota_type=ProviderQuotaType.TRIAL,
        quota_configurations=[trial_quota],
    )
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=system_configuration,
    )
    manager = MagicMock()
    manager.get_configurations.return_value.get.return_value = provider_configuration

    # Seed an in-memory SQLite table with a pool that has one unit left (9 of 10 used).
    sqlite_engine = create_engine("sqlite:///:memory:")
    pool_table = TenantCreditPool.__table__
    pool_table.create(sqlite_engine)
    seed_row = {
        "id": "trial-pool",
        "tenant_id": "tenant-id",
        "pool_type": ModelProviderQuotaType.TRIAL,
        "quota_limit": 10,
        "quota_used": 9,
    }
    with sqlite_engine.begin() as conn:
        conn.execute(pool_table.insert(), seed_row)

    # Only the provider manager factory and the service's ``db`` handle are replaced;
    # CreditPoolService itself is not mocked in this test.
    manager_patch = patch("core.app.llm.quota.create_plugin_provider_manager", return_value=manager)
    db_patch = patch("services.credit_pool_service.db", SimpleNamespace(engine=sqlite_engine))
    with manager_patch, db_patch:
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=llm_usage,
        )

    used_lookup = select(TenantCreditPool.quota_used).where(TenantCreditPool.id == "trial-pool")
    with sqlite_engine.connect() as conn:
        quota_used = conn.scalar(used_lookup)

    assert quota_used == 10
def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 42
@ -133,7 +184,7 @@ def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None:
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
@ -166,7 +217,7 @@ def test_deduct_llm_quota_for_model_uses_credit_configuration() -> None:
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch.object(type(dify_config), "get_model_credits", return_value=9) as mock_get_model_credits,
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
@ -202,7 +253,7 @@ def test_deduct_llm_quota_for_model_uses_single_charge_for_times_quota() -> None
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
@ -238,7 +289,7 @@ def test_deduct_llm_quota_for_model_uses_paid_billing_pool() -> None:
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
@ -439,7 +490,7 @@ def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None:
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
):
deduct_llm_quota_for_model(
@ -468,7 +519,7 @@ def test_deduct_llm_quota_for_model_ignores_custom_provider_configuration() -> N
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits,
patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker,
):
deduct_llm_quota_for_model(