fix(api): split exact and capped credit deduction

This commit is contained in:
-LAN- 2026-05-08 15:39:30 +08:00
parent 6768fdd9f8
commit a8bb64bcf8
6 changed files with 206 additions and 41 deletions

View File

@ -132,6 +132,7 @@ class LLMQuotaLayer(GraphEngineLayer):
error_type=error_type,
)
# TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override.
node._run = quota_aborted_run # type: ignore[method-assign]
self._send_abort_command(reason=reason)

View File

@ -137,17 +137,13 @@ def handle(sender: Message, **kwargs):
if used_quota is not None:
match provider_configuration.system_configuration.current_quota_type:
case ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
_deduct_credit_pool_quota_capped(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="trial",
)
case ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
_deduct_credit_pool_quota_capped(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
@ -200,6 +196,26 @@ def handle(sender: Message, **kwargs):
raise
def _deduct_credit_pool_quota_capped(*, tenant_id: str, credits_required: int, pool_type: str) -> None:
"""Apply post-generation credit accounting without failing message persistence on quota exhaustion."""
from services.credit_pool_service import CreditPoolService
deducted_credits = CreditPoolService.deduct_credits_capped(
tenant_id=tenant_id,
credits_required=credits_required,
pool_type=pool_type,
)
if deducted_credits < credits_required:
logger.warning(
"Credit pool exhausted during message-created accounting, "
"tenant_id=%s, pool_type=%s, credits_required=%s, credits_deducted=%s",
tenant_id,
pool_type,
credits_required,
deducted_credits,
)
def _calculate_quota_usage(
*, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> int | None:

View File

@ -1,7 +1,7 @@
import logging
from sqlalchemy import select, update
from sqlalchemy.orm import sessionmaker
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.errors.error import QuotaExceededError
@ -13,6 +13,18 @@ logger = logging.getLogger(__name__)
class CreditPoolService:
@staticmethod
def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
return session.scalar(
select(TenantCreditPool)
.where(
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.pool_type == pool_type,
)
.limit(1)
.with_for_update()
)
@classmethod
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
"""create default credit pool for new tenant"""
@ -59,35 +71,57 @@ class CreditPoolService:
credits_required: int,
pool_type: str = "trial",
) -> int:
"""Deduct credits, depleting the pool before raising if the balance is insufficient."""
pool = cls.get_pool(tenant_id, pool_type)
if not pool:
raise QuotaExceededError("Credit pool not found")
if pool.remaining_credits <= 0:
raise QuotaExceededError("No credits remaining")
remaining_credits = pool.remaining_credits
actual_credits = min(credits_required, remaining_credits)
quota_exceeded = actual_credits < credits_required
"""Deduct exactly the requested credits or raise without mutating the pool."""
if credits_required <= 0:
return 0
try:
with sessionmaker(db.engine).begin() as session:
stmt = (
update(TenantCreditPool)
.where(
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.pool_type == pool_type,
)
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
)
session.execute(stmt)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
if not pool:
raise QuotaExceededError("Credit pool not found")
remaining_credits = pool.remaining_credits
if remaining_credits <= 0:
raise QuotaExceededError("No credits remaining")
if remaining_credits < credits_required:
raise QuotaExceededError("Insufficient credits remaining")
pool.quota_used += credits_required
except QuotaExceededError:
raise
except Exception:
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")
if quota_exceeded:
raise QuotaExceededError("Insufficient credits remaining")
return credits_required
return actual_credits
@classmethod
def deduct_credits_capped(
cls,
tenant_id: str,
credits_required: int,
pool_type: str = "trial",
) -> int:
"""Deduct up to the available balance and return the actual deducted credits."""
if credits_required <= 0:
return 0
try:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
if not pool:
logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type)
return 0
deducted_credits = min(credits_required, pool.remaining_credits)
if deducted_credits <= 0:
return 0
pool.quota_used += deducted_credits
return deducted_credits
except QuotaExceededError:
raise
except Exception:
logger.exception("Failed to deduct capped credits for tenant %s", tenant_id)
raise QuotaExceededError("Failed to deduct credits")

View File

@ -90,11 +90,14 @@ class TestCreditPoolService:
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert pool.quota_used == credits_required
def test_check_and_deduct_credits_depletes_and_raises_when_insufficient(self, db_session_with_containers: Session):
def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
self, db_session_with_containers: Session
):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
quota_used = pool.quota_used
db_session_with_containers.commit()
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
@ -102,4 +105,19 @@ class TestCreditPoolService:
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == pool.quota_limit
assert updated_pool.quota_used == quota_used
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
tenant_id = self._create_tenant_id()
pool = CreditPoolService.create_default_pool(tenant_id)
remaining = 5
pool.quota_used = pool.quota_limit - remaining
quota_limit = pool.quota_limit
db_session_with_containers.commit()
result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
assert result == remaining
db_session_with_containers.expire_all()
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
assert updated_pool.quota_used == quota_limit

View File

@ -0,0 +1,68 @@
from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
from sqlalchemy import create_engine, select
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from events.event_handlers import update_provider_when_message_created
from models import TenantCreditPool
from models.provider import ProviderType
def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
tenant_id = str(uuid4())
pool_id = str(uuid4())
with engine.begin() as connection:
connection.execute(
TenantCreditPool.__table__.insert(),
{
"id": pool_id,
"tenant_id": tenant_id,
"pool_type": ProviderQuotaType.TRIAL,
"quota_limit": 10,
"quota_used": 9,
},
)
system_configuration = SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=10,
)
],
)
application_generate_entity = ChatAppGenerateEntity.model_construct(
app_config=SimpleNamespace(tenant_id=tenant_id),
model_conf=SimpleNamespace(
provider="openai",
model="gpt-4o",
provider_model_bundle=SimpleNamespace(
configuration=SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=system_configuration,
)
),
),
)
message = SimpleNamespace(message_tokens=2, answer_tokens=1)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
patch.object(update_provider_when_message_created, "_execute_provider_updates"),
):
update_provider_when_message_created.handle(
sender=message,
application_generate_entity=application_generate_entity,
)
with engine.connect() as connection:
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
assert quota_used == 10

View File

@ -4,6 +4,7 @@ from uuid import uuid4
import pytest
from sqlalchemy import create_engine, select
from sqlalchemy.engine import Engine
from core.errors.error import QuotaExceededError
from models import TenantCreditPool
@ -11,7 +12,7 @@ from models.enums import ProviderQuotaType
from services.credit_pool_service import CreditPoolService
def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -> None:
def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
engine = create_engine("sqlite:///:memory:")
TenantCreditPool.__table__.create(engine)
tenant_id = str(uuid4())
@ -23,10 +24,30 @@ def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -
"id": pool_id,
"tenant_id": tenant_id,
"pool_type": ProviderQuotaType.TRIAL,
"quota_limit": 10,
"quota_used": 9,
"quota_limit": quota_limit,
"quota_used": quota_used,
},
)
return engine, tenant_id, pool_id
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
with engine.connect() as connection:
return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
assert deducted_credits == 3
assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
with (
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
@ -34,7 +55,14 @@ def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -
):
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
with engine.connect() as connection:
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
assert quota_used == 10
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
assert deducted_credits == 1
assert _get_quota_used(engine=engine, pool_id=pool_id) == 10