From 8f9127043e33c5c8781d5607944743869eb53aea Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 7 May 2026 18:15:21 +0800 Subject: [PATCH] fix(quota): caps free quota Signed-off-by: -LAN- --- api/core/app/llm/quota.py | 64 +++++++++++++------ api/services/credit_pool_service.py | 10 ++- .../services/test_credit_pool_service.py | 8 ++- .../unit_tests/core/app/test_llm_quota.py | 54 ++++++++++++++++ .../services/test_credit_pool_service.py | 40 ++++++++++++ 5 files changed, 152 insertions(+), 24 deletions(-) create mode 100644 api/tests/unit_tests/services/test_credit_pool_service.py diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index aa3ee38e83..64a9229ec8 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -8,7 +8,7 @@ with a non-LLM model. import warnings -from sqlalchemy import update +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from configs import dify_config @@ -72,6 +72,45 @@ def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage return used_quota +def _deduct_free_llm_quota( + *, + tenant_id: str, + provider: str, + quota_type: ProviderQuotaType, + used_quota: int, +) -> None: + """Deduct FREE provider quota, capping at the limit before reporting exhaustion.""" + quota_exceeded = False + with sessionmaker(bind=db.engine).begin() as session: + provider_record = session.scalar( + select(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. + Provider.provider_name == ModelProviderID(provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == quota_type, + ) + .with_for_update() + ) + if ( + provider_record is None + or provider_record.quota_limit is None + or provider_record.quota_used is None + or provider_record.quota_limit <= provider_record.quota_used + ): + quota_exceeded = True + else: + available_quota = provider_record.quota_limit - provider_record.quota_used + deducted_quota = min(used_quota, available_quota) + provider_record.quota_used += deducted_quota + provider_record.last_used = naive_utc_now() + quota_exceeded = deducted_quota < used_quota + + if quota_exceeded: + raise QuotaExceededError(f"Model provider {provider} quota exceeded.") + + def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None: """Apply a resolved LLM quota charge against the current provider quota bucket.""" if provider_configuration.using_provider_type != ProviderType.SYSTEM: @@ -96,23 +135,12 @@ def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configurat pool_type="paid", ) case ProviderQuotaType.FREE: - with sessionmaker(bind=db.engine).begin() as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) - ) - session.execute(stmt) + _deduct_free_llm_quota( + tenant_id=tenant_id, + provider=provider, + quota_type=system_configuration.current_quota_type, + used_quota=used_quota, + ) case _: return diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 2d210db121..9d57cc0ac9 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -59,7 +59,7 @@ class CreditPoolService: credits_required: int, pool_type: str = "trial", ) -> int: - """check and deduct credits, returns actual credits deducted""" + """Deduct credits, depleting the pool before raising if the balance is insufficient.""" pool = cls.get_pool(tenant_id, pool_type) if not pool: @@ -68,8 +68,9 @@ class CreditPoolService: if pool.remaining_credits <= 0: raise QuotaExceededError("No credits remaining") - # deduct all remaining credits if less than required - actual_credits = min(credits_required, pool.remaining_credits) + remaining_credits = pool.remaining_credits + actual_credits = min(credits_required, remaining_credits) + quota_exceeded = actual_credits < credits_required try: with sessionmaker(db.engine).begin() as session: @@ -86,4 +87,7 @@ class CreditPoolService: logger.exception("Failed to deduct credits for tenant %s", tenant_id) raise QuotaExceededError("Failed to deduct credits") + if quota_exceeded: + raise QuotaExceededError("Insufficient credits remaining") + return actual_credits diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 09ba041244..02846acdb7 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -90,16 +90,18 @@ class TestCreditPoolService: pool = CreditPoolService.get_pool(tenant_id=tenant_id) assert pool.quota_used == credits_required - def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session): + def test_check_and_deduct_credits_depletes_and_raises_when_insufficient( + self, db_session_with_containers: Session + ): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) remaining = 5 pool.quota_used = pool.quota_limit - remaining db_session_with_containers.commit() - result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) + with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) - assert result == remaining db_session_with_containers.expire_all() updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) assert updated_pool.quota_used == pool.quota_limit diff --git a/api/tests/unit_tests/core/app/test_llm_quota.py b/api/tests/unit_tests/core/app/test_llm_quota.py index de6aa1ec9f..7275d6b512 100644 --- a/api/tests/unit_tests/core/app/test_llm_quota.py +++ b/api/tests/unit_tests/core/app/test_llm_quota.py @@ -350,6 +350,7 @@ def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None: with ( patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)), + pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."), ): deduct_llm_quota_for_model( tenant_id="tenant-id", @@ -364,6 +365,59 @@ def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None: assert exhausted_quota_used == 13 +def test_deduct_llm_quota_for_model_caps_free_quota_and_raises_when_usage_exceeds_remaining() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 3 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.FREE, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + engine = create_engine("sqlite:///:memory:") + Provider.__table__.create(engine) + with engine.begin() as connection: + connection.execute( + Provider.__table__.insert(), + { + "id": "matching-provider", + "tenant_id": "tenant-id", + "provider_name": "openai", + "provider_type": ProviderType.SYSTEM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 15, + "quota_used": 13, + "is_valid": True, + }, + ) + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)), + pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."), + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + with engine.connect() as connection: + quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider")) + + assert quota_used == 15 + + def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None: usage = LLMUsage.empty_usage() usage.total_tokens = 2 diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py new file mode 100644 index 0000000000..8d30f4056d --- /dev/null +++ b/api/tests/unit_tests/services/test_credit_pool_service.py @@ -0,0 +1,40 @@ +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import create_engine, select + +from core.errors.error import QuotaExceededError +from models import TenantCreditPool +from models.enums import ProviderQuotaType +from services.credit_pool_service import CreditPoolService + + +def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -> None: + engine = create_engine("sqlite:///:memory:") + TenantCreditPool.__table__.create(engine) + tenant_id = str(uuid4()) + pool_id = str(uuid4()) + with engine.begin() as connection: + connection.execute( + TenantCreditPool.__table__.insert(), + { + "id": pool_id, + "tenant_id": tenant_id, + "pool_type": ProviderQuotaType.TRIAL, + "quota_limit": 10, + "quota_used": 9, + }, + ) + + with ( + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + pytest.raises(QuotaExceededError, match="Insufficient credits remaining"), + ): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3) + + with engine.connect() as connection: + quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)) + + assert quota_used == 10