mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
fix(quota): caps free quota
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
f4370bf4f6
commit
8f9127043e
@ -8,7 +8,7 @@ with a non-LLM model.
|
||||
|
||||
import warnings
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
@ -72,6 +72,45 @@ def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage
|
||||
return used_quota
|
||||
|
||||
|
||||
def _deduct_free_llm_quota(
    *,
    tenant_id: str,
    provider: str,
    quota_type: ProviderQuotaType,
    used_quota: int,
) -> None:
    """Charge ``used_quota`` to the tenant's system provider quota row, never exceeding its limit.

    The matching ``Provider`` row is selected with ``FOR UPDATE`` inside a
    ``sessionmaker(...).begin()`` block. If the row has room left, the smaller
    of ``used_quota`` and the remaining room is added to ``quota_used`` and
    ``last_used`` is refreshed.

    Raises:
        QuotaExceededError: when no matching row exists, when ``quota_limit``
            or ``quota_used`` is ``None``, when the row is already at or over
            its limit, or when only part of ``used_quota`` could be charged.
            The raise happens after the session block has exited, so a partial
            charge is written before the error is reported.
    """
    # Assume the charge cannot be fully honoured until the row proves otherwise.
    shortfall = True

    with sessionmaker(bind=db.engine).begin() as session:
        locked_lookup = (
            select(Provider)
            .where(
                Provider.tenant_id == tenant_id,
                # TODO: Use provider name with prefix after the data migration.
                Provider.provider_name == ModelProviderID(provider).provider_name,
                Provider.provider_type == ProviderType.SYSTEM.value,
                Provider.quota_type == quota_type,
            )
            .with_for_update()
        )
        row = session.scalar(locked_lookup)

        has_headroom = (
            row is not None
            and row.quota_limit is not None
            and row.quota_used is not None
            and row.quota_limit > row.quota_used
        )
        if has_headroom:
            headroom = row.quota_limit - row.quota_used
            charge = min(used_quota, headroom)
            row.quota_used += charge
            row.last_used = naive_utc_now()
            # A capped charge still counts as exhaustion for the caller.
            shortfall = charge < used_quota

    if shortfall:
        raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
|
||||
|
||||
|
||||
def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None:
|
||||
"""Apply a resolved LLM quota charge against the current provider quota bucket."""
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
@ -96,23 +135,12 @@ def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configurat
|
||||
pool_type="paid",
|
||||
)
|
||||
case ProviderQuotaType.FREE:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
_deduct_free_llm_quota(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
quota_type=system_configuration.current_quota_type,
|
||||
used_quota=used_quota,
|
||||
)
|
||||
case _:
|
||||
return
|
||||
|
||||
|
||||
@ -59,7 +59,7 @@ class CreditPoolService:
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
) -> int:
|
||||
"""check and deduct credits, returns actual credits deducted"""
|
||||
"""Deduct credits, depleting the pool before raising if the balance is insufficient."""
|
||||
|
||||
pool = cls.get_pool(tenant_id, pool_type)
|
||||
if not pool:
|
||||
@ -68,8 +68,9 @@ class CreditPoolService:
|
||||
if pool.remaining_credits <= 0:
|
||||
raise QuotaExceededError("No credits remaining")
|
||||
|
||||
# deduct all remaining credits if less than required
|
||||
actual_credits = min(credits_required, pool.remaining_credits)
|
||||
remaining_credits = pool.remaining_credits
|
||||
actual_credits = min(credits_required, remaining_credits)
|
||||
quota_exceeded = actual_credits < credits_required
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
@ -86,4 +87,7 @@ class CreditPoolService:
|
||||
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
if quota_exceeded:
|
||||
raise QuotaExceededError("Insufficient credits remaining")
|
||||
|
||||
return actual_credits
|
||||
|
||||
@ -90,16 +90,18 @@ class TestCreditPoolService:
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert pool.quota_used == credits_required
|
||||
|
||||
def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session):
|
||||
def test_check_and_deduct_credits_depletes_and_raises_when_insufficient(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
|
||||
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200)
|
||||
|
||||
assert result == remaining
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == pool.quota_limit
|
||||
|
||||
@ -350,6 +350,7 @@ def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None:
|
||||
with (
|
||||
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
|
||||
patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
|
||||
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
|
||||
):
|
||||
deduct_llm_quota_for_model(
|
||||
tenant_id="tenant-id",
|
||||
@ -364,6 +365,59 @@ def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None:
|
||||
assert exhausted_quota_used == 13
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_caps_free_quota_and_raises_when_usage_exceeds_remaining() -> None:
    """Seed a FREE provider row at 13/15, report 3 total tokens, expect an error and a stored value of 15."""
    free_quota_configuration = SimpleNamespace(
        quota_type=ProviderQuotaType.FREE,
        quota_unit=QuotaUnit.TOKENS,
        quota_limit=100,
    )
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.FREE,
            quota_configurations=[free_quota_configuration],
        ),
    )
    provider_manager = MagicMock()
    provider_manager.get_configurations.return_value.get.return_value = provider_configuration

    usage = LLMUsage.empty_usage()
    usage.total_tokens = 3

    # In-memory SQLite stands in for the real database; only the Provider table is needed.
    engine = create_engine("sqlite:///:memory:")
    Provider.__table__.create(engine)
    seeded_row = {
        "id": "matching-provider",
        "tenant_id": "tenant-id",
        "provider_name": "openai",
        "provider_type": ProviderType.SYSTEM,
        "quota_type": ProviderQuotaType.FREE,
        "quota_limit": 15,
        "quota_used": 13,
        "is_valid": True,
    }
    with engine.begin() as connection:
        connection.execute(Provider.__table__.insert(), seeded_row)

    with (
        patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
        patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)),
        pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
    ):
        deduct_llm_quota_for_model(
            tenant_id="tenant-id",
            provider="openai",
            model="gpt-4o",
            usage=usage,
        )

    # Read back through a fresh connection so the check sees what was actually stored.
    stored_quota_lookup = select(Provider.quota_used).where(Provider.id == "matching-provider")
    with engine.connect() as connection:
        stored_quota_used = connection.scalar(stored_quota_lookup)

    assert stored_quota_used == 15
|
||||
|
||||
|
||||
def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.total_tokens = 2
|
||||
|
||||
40
api/tests/unit_tests/services/test_credit_pool_service.py
Normal file
40
api/tests/unit_tests/services/test_credit_pool_service.py
Normal file
@ -0,0 +1,40 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from models.enums import ProviderQuotaType
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -> None:
    """Seed a trial pool at 9/10, request 3 credits, expect an error and a stored quota_used of 10."""
    # In-memory SQLite stands in for the real database; only the credit pool table is needed.
    engine = create_engine("sqlite:///:memory:")
    TenantCreditPool.__table__.create(engine)

    tenant_id = str(uuid4())
    pool_id = str(uuid4())
    seeded_pool = {
        "id": pool_id,
        "tenant_id": tenant_id,
        "pool_type": ProviderQuotaType.TRIAL,
        "quota_limit": 10,
        "quota_used": 9,
    }
    with engine.begin() as connection:
        connection.execute(TenantCreditPool.__table__.insert(), seeded_pool)

    with (
        patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
        pytest.raises(QuotaExceededError, match="Insufficient credits remaining"),
    ):
        CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)

    # Read back through a fresh connection so the check sees what was actually stored.
    stored_quota_lookup = select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)
    with engine.connect() as connection:
        stored_quota_used = connection.scalar(stored_quota_lookup)

    assert stored_quota_used == 10
|
||||
Loading…
Reference in New Issue
Block a user