fix(quota): caps free quota

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2026-05-07 18:15:21 +08:00
parent b66c6760ad
commit 19171e0eb8
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
5 changed files with 152 additions and 24 deletions

View File

@ -8,7 +8,7 @@ with a non-LLM model.
import warnings
from sqlalchemy import update
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
@ -72,6 +72,45 @@ def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage
return used_quota
def _deduct_free_llm_quota(
    *,
    tenant_id: str,
    provider: str,
    quota_type: ProviderQuotaType,
    used_quota: int,
) -> None:
    """Charge ``used_quota`` to the tenant's system provider row, never past its limit.

    The matching ``Provider`` row is selected with ``with_for_update()`` inside
    a transaction. When the row still has spare quota, ``quota_used`` grows by
    at most the remaining amount and ``last_used`` is refreshed.

    Raises:
        QuotaExceededError: If there is no usable row (row missing, limit or
            usage is ``None``, or the limit is already reached), or if the
            charge had to be truncated to the remaining quota. The error is
            raised only after the transaction block has exited, so a truncated
            charge is not discarded by the exception.
    """
    # Pessimistic default: only a full, untruncated charge clears the flag.
    shortfall = True
    with sessionmaker(bind=db.engine).begin() as session:
        lookup = (
            select(Provider)
            .where(
                Provider.tenant_id == tenant_id,
                # TODO: Use provider name with prefix after the data migration.
                Provider.provider_name == ModelProviderID(provider).provider_name,
                Provider.provider_type == ProviderType.SYSTEM.value,
                Provider.quota_type == quota_type,
            )
            .with_for_update()
        )
        record = session.scalar(lookup)
        if record is not None and record.quota_limit is not None and record.quota_used is not None:
            headroom = record.quota_limit - record.quota_used
            if headroom > 0:
                # Take what is requested, but no more than what is left.
                charge = min(used_quota, headroom)
                record.quota_used += charge
                record.last_used = naive_utc_now()
                shortfall = charge < used_quota

    # Raised outside the ``with`` so the exception cannot roll the update back.
    if shortfall:
        raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None:
"""Apply a resolved LLM quota charge against the current provider quota bucket."""
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
@ -96,23 +135,12 @@ def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configurat
pool_type="paid",
)
case ProviderQuotaType.FREE:
with sessionmaker(bind=db.engine).begin() as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
_deduct_free_llm_quota(
tenant_id=tenant_id,
provider=provider,
quota_type=system_configuration.current_quota_type,
used_quota=used_quota,
)
case _:
return

View File

@ -59,7 +59,7 @@ class CreditPoolService:
credits_required: int,
pool_type: str = "trial",
) -> int:
"""check and deduct credits, returns actual credits deducted"""
"""Deduct credits, depleting the pool before raising if the balance is insufficient."""
pool = cls.get_pool(tenant_id, pool_type)
if not pool:
@ -68,8 +68,9 @@ class CreditPoolService:
if pool.remaining_credits <= 0:
raise QuotaExceededError("No credits remaining")
# deduct all remaining credits if less than required
actual_credits = min(credits_required, pool.remaining_credits)
remaining_credits = pool.remaining_credits
actual_credits = min(credits_required, remaining_credits)
quota_exceeded = actual_credits < credits_required
try:
with sessionmaker(db.engine).begin() as session:
@ -86,4 +87,7 @@ class CreditPoolService:
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")
if quota_exceeded:
raise QuotaExceededError("Insufficient credits remaining")
return actual_credits

View File

@ -90,16 +90,18 @@ class TestCreditPoolService:
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
def test_check_and_deduct_credits_depletes_and_raises_when_insufficient(
self, db_session_with_containers: Session
):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
db_session_with_containers.commit()
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
assert result == remaining
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == pool.quota_limit

View File

@ -350,6 +350,7 @@ def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None:
with (
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
deduct_llm_quota_for_model(
tenant_id="tenant-id",
@ -364,6 +365,59 @@ def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None:
assert exhausted_quota_used == 13
def test_deduct_llm_quota_for_model_caps_free_quota_and_raises_when_usage_exceeds_remaining() -> None:
    """A request larger than the remaining FREE quota fills the row and raises.

    The seeded row has ``quota_limit=15`` and ``quota_used=13``; the usage
    reports 3 total tokens. The call must raise ``QuotaExceededError`` and the
    stored ``quota_used`` must end at the limit (15).
    """
    # Provider configuration as returned by the (mocked) provider manager.
    free_quota_configuration = SimpleNamespace(
        quota_type=ProviderQuotaType.FREE,
        quota_unit=QuotaUnit.TOKENS,
        quota_limit=100,
    )
    system_configuration = SimpleNamespace(
        current_quota_type=ProviderQuotaType.FREE,
        quota_configurations=[free_quota_configuration],
    )
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=system_configuration,
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    usage = LLMUsage.empty_usage()
    usage.total_tokens = 3

    # In-memory database holding a single provider row with 2 units left.
    engine = create_engine("sqlite:///:memory:")
    Provider.__table__.create(engine)
    seeded_row = {
        "id": "matching-provider",
        "tenant_id": "tenant-id",
        "provider_name": "openai",
        "provider_type": ProviderType.SYSTEM,
        "quota_type": ProviderQuotaType.FREE,
        "quota_limit": 15,
        "quota_used": 13,
        "is_valid": True,
    }
    with engine.begin() as connection:
        connection.execute(Provider.__table__.insert(), seeded_row)

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
        pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    # Read the row back on a fresh connection to see what was persisted.
    stored_usage_query = select(Provider.quota_used).where(Provider.id == "matching-provider")
    with engine.connect() as connection:
        persisted_quota_used = connection.scalar(stored_usage_query)

    assert persisted_quota_used == 15
def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 2

View File

@ -0,0 +1,40 @@
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine, select
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
from models.enums import ProviderQuotaType
from services.credit_pool_service import CreditPoolService
def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -> None:
    """Requesting more credits than remain drains the pool and raises.

    The seeded pool has ``quota_limit=10`` and ``quota_used=9``; 3 credits are
    requested. The call must raise ``QuotaExceededError`` and the stored
    ``quota_used`` must end at the limit (10).
    """
    tenant_id = str(uuid4())
    pool_id = str(uuid4())

    # In-memory database holding one trial pool with a single credit left.
    engine = create_engine("sqlite:///:memory:")
    TenantCreditPool.__table__.create(engine)
    seeded_pool = {
        "id": pool_id,
        "tenant_id": tenant_id,
        "pool_type": ProviderQuotaType.TRIAL,
        "quota_limit": 10,
        "quota_used": 9,
    }
    with engine.begin() as connection:
        connection.execute(TenantCreditPool.__table__.insert(), seeded_pool)

    with (
        patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
        pytest.raises(QuotaExceededError, match="Insufficient credits remaining"),
    ):
        CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)

    # Read the row back on a fresh connection to see what was persisted.
    stored_usage_query = select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)
    with engine.connect() as connection:
        persisted_quota_used = connection.scalar(stored_usage_query)

    assert persisted_quota_used == 10