mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 04:36:31 +08:00
fix(api): split exact and capped credit deduction
This commit is contained in:
parent
6768fdd9f8
commit
a8bb64bcf8
@ -132,6 +132,7 @@ class LLMQuotaLayer(GraphEngineLayer):
|
||||
error_type=error_type,
|
||||
)
|
||||
|
||||
# TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override.
|
||||
node._run = quota_aborted_run # type: ignore[method-assign]
|
||||
self._send_abort_command(reason=reason)
|
||||
|
||||
|
||||
@ -137,17 +137,13 @@ def handle(sender: Message, **kwargs):
|
||||
if used_quota is not None:
|
||||
match provider_configuration.system_configuration.current_quota_type:
|
||||
case ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
_deduct_credit_pool_quota_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="trial",
|
||||
)
|
||||
case ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
_deduct_credit_pool_quota_capped(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
@ -200,6 +196,26 @@ def handle(sender: Message, **kwargs):
|
||||
raise
|
||||
|
||||
|
||||
def _deduct_credit_pool_quota_capped(*, tenant_id: str, credits_required: int, pool_type: str) -> None:
    """Charge a tenant's credit pool for a created message, capped at what the pool can cover.

    Delegates to ``CreditPoolService.deduct_credits_capped`` and logs a warning
    when fewer credits were deducted than requested. A shortfall is only logged
    here; this function does not raise for it.
    """
    # Imported inside the function rather than at module level.
    from services.credit_pool_service import CreditPoolService

    charged = CreditPoolService.deduct_credits_capped(
        tenant_id=tenant_id, credits_required=credits_required, pool_type=pool_type
    )
    if charged >= credits_required:
        # The full amount was covered; nothing to report.
        return

    logger.warning(
        "Credit pool exhausted during message-created accounting, "
        "tenant_id=%s, pool_type=%s, credits_required=%s, credits_deducted=%s",
        tenant_id,
        pool_type,
        credits_required,
        charged,
    )
|
||||
|
||||
|
||||
def _calculate_quota_usage(
|
||||
*, message: Message, system_configuration: SystemConfiguration, model_name: str
|
||||
) -> int | None:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
@ -13,6 +13,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreditPoolService:
|
||||
@staticmethod
def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
    """Load one tenant credit pool row using a ``FOR UPDATE`` select.

    The query runs on the caller-supplied ``session`` and is limited to a
    single row matching both ``tenant_id`` and ``pool_type``. Returns that
    row, or ``None`` when nothing matches.
    """
    same_tenant = TenantCreditPool.tenant_id == tenant_id
    same_pool_type = TenantCreditPool.pool_type == pool_type
    locking_query = select(TenantCreditPool).where(same_tenant, same_pool_type).limit(1).with_for_update()
    return session.scalar(locking_query)
|
||||
|
||||
@classmethod
|
||||
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""create default credit pool for new tenant"""
|
||||
@ -59,35 +71,57 @@ class CreditPoolService:
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
) -> int:
|
||||
"""Deduct credits, depleting the pool before raising if the balance is insufficient."""
|
||||
|
||||
pool = cls.get_pool(tenant_id, pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
if pool.remaining_credits <= 0:
|
||||
raise QuotaExceededError("No credits remaining")
|
||||
|
||||
remaining_credits = pool.remaining_credits
|
||||
actual_credits = min(credits_required, remaining_credits)
|
||||
quota_exceeded = actual_credits < credits_required
|
||||
"""Deduct exactly the requested credits or raise without mutating the pool."""
|
||||
if credits_required <= 0:
|
||||
return 0
|
||||
|
||||
try:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(TenantCreditPool)
|
||||
.where(
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
)
|
||||
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
|
||||
)
|
||||
session.execute(stmt)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
remaining_credits = pool.remaining_credits
|
||||
if remaining_credits <= 0:
|
||||
raise QuotaExceededError("No credits remaining")
|
||||
if remaining_credits < credits_required:
|
||||
raise QuotaExceededError("Insufficient credits remaining")
|
||||
|
||||
pool.quota_used += credits_required
|
||||
except QuotaExceededError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
if quota_exceeded:
|
||||
raise QuotaExceededError("Insufficient credits remaining")
|
||||
return credits_required
|
||||
|
||||
return actual_credits
|
||||
@classmethod
def deduct_credits_capped(
    cls,
    tenant_id: str,
    credits_required: int,
    pool_type: str = "trial",
) -> int:
    """Deduct up to the available balance and return the actual deducted credits.

    Args:
        tenant_id: Tenant whose credit pool is charged.
        credits_required: Credits requested; a value <= 0 is a no-op.
        pool_type: Which pool to charge (``"trial"`` by default).

    Returns:
        The number of credits actually deducted, i.e.
        ``min(credits_required, pool.remaining_credits)``. Returns 0 when
        nothing was requested, when no matching pool exists (a warning is
        logged), or when the pool has nothing left.

    Raises:
        QuotaExceededError: if the deduction fails with an unexpected
            exception. The original exception is logged and attached as
            the explicit cause (``__cause__``).
    """
    if credits_required <= 0:
        return 0

    try:
        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
            # The helper selects the row with FOR UPDATE before it is modified.
            pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
            if not pool:
                logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type)
                return 0

            # Never charge more than what is left in the pool.
            deducted_credits = min(credits_required, pool.remaining_credits)
            if deducted_credits <= 0:
                return 0

            pool.quota_used += deducted_credits
            return deducted_credits
    except QuotaExceededError:
        # Let quota errors propagate untouched instead of re-wrapping them below.
        raise
    except Exception as exc:
        logger.exception("Failed to deduct capped credits for tenant %s", tenant_id)
        # Chain explicitly so the underlying failure stays visible to callers and in tracebacks.
        raise QuotaExceededError("Failed to deduct credits") from exc
|
||||
|
||||
@ -90,11 +90,14 @@ class TestCreditPoolService:
|
||||
pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert pool.quota_used == credits_required
|
||||
|
||||
def test_check_and_deduct_credits_depletes_and_raises_when_insufficient(self, db_session_with_containers: Session):
|
||||
def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
tenant_id = self._create_tenant_id()
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
remaining = 5
|
||||
pool.quota_used = pool.quota_limit - remaining
|
||||
quota_used = pool.quota_used
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
|
||||
@ -102,4 +105,19 @@ class TestCreditPoolService:
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
|
||||
assert updated_pool.quota_used == pool.quota_limit
|
||||
assert updated_pool.quota_used == quota_used
|
||||
|
||||
def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
    """Requesting more than the remaining balance deducts only what is left."""
    remaining = 5
    tenant_id = self._create_tenant_id()
    pool = CreditPoolService.create_default_pool(tenant_id)
    # Capture the limit before commit so the final assertion does not touch the original instance.
    quota_limit = pool.quota_limit
    pool.quota_used = quota_limit - remaining
    db_session_with_containers.commit()

    deducted = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
    assert deducted == remaining

    # Drop cached state so the pool is re-read from the database.
    db_session_with_containers.expire_all()
    refreshed_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
    assert refreshed_pool.quota_used == quota_limit
|
||||
|
||||
@ -0,0 +1,68 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from events.event_handlers import update_provider_when_message_created
|
||||
from models import TenantCreditPool
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None:
    """The message-created handler must not raise when the trial pool cannot cover the usage.

    The pool is seeded with ``quota_limit=10`` and ``quota_used=9``; after the
    handler runs, ``quota_used`` is expected to be 10.
    """
    engine = create_engine("sqlite:///:memory:")
    pool_table = TenantCreditPool.__table__
    pool_table.create(engine)

    tenant_id = str(uuid4())
    pool_id = str(uuid4())
    seed_row = {
        "id": pool_id,
        "tenant_id": tenant_id,
        "pool_type": ProviderQuotaType.TRIAL,
        "quota_limit": 10,
        "quota_used": 9,
    }
    with engine.begin() as connection:
        connection.execute(pool_table.insert(), seed_row)

    # Lightweight stand-ins for the provider configuration objects the handler reads.
    trial_quota = SimpleNamespace(
        quota_type=ProviderQuotaType.TRIAL,
        quota_unit=QuotaUnit.TOKENS,
        quota_limit=10,
    )
    system_configuration = SimpleNamespace(
        current_quota_type=ProviderQuotaType.TRIAL,
        quota_configurations=[trial_quota],
    )
    provider_configuration = SimpleNamespace(
        using_provider_type=ProviderType.SYSTEM,
        system_configuration=system_configuration,
    )
    model_conf = SimpleNamespace(
        provider="openai",
        model="gpt-4o",
        provider_model_bundle=SimpleNamespace(configuration=provider_configuration),
    )
    application_generate_entity = ChatAppGenerateEntity.model_construct(
        app_config=SimpleNamespace(tenant_id=tenant_id),
        model_conf=model_conf,
    )
    message = SimpleNamespace(message_tokens=2, answer_tokens=1)

    fake_db = SimpleNamespace(engine=engine)
    with (
        patch("services.credit_pool_service.db", fake_db),
        patch.object(update_provider_when_message_created, "_execute_provider_updates"),
    ):
        update_provider_when_message_created.handle(
            sender=message,
            application_generate_entity=application_generate_entity,
        )

    quota_used_query = select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)
    with engine.connect() as connection:
        quota_used = connection.scalar(quota_used_query)

    assert quota_used == 10
|
||||
@ -4,6 +4,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
@ -11,7 +12,7 @@ from models.enums import ProviderQuotaType
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -> None:
|
||||
def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TenantCreditPool.__table__.create(engine)
|
||||
tenant_id = str(uuid4())
|
||||
@ -23,10 +24,30 @@ def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -
|
||||
"id": pool_id,
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": ProviderQuotaType.TRIAL,
|
||||
"quota_limit": 10,
|
||||
"quota_used": 9,
|
||||
"quota_limit": quota_limit,
|
||||
"quota_used": quota_used,
|
||||
},
|
||||
)
|
||||
return engine, tenant_id, pool_id
|
||||
|
||||
|
||||
def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
    """Return ``quota_used`` for the pool row with id ``pool_id``, read on a new connection."""
    quota_used_query = select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)
    with engine.connect() as connection:
        return connection.scalar(quota_used_query)
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
    """With 8 credits left, requesting 3 deducts exactly 3 (2 used -> 5 used)."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
    fake_db = SimpleNamespace(engine=engine)

    with patch("services.credit_pool_service.db", fake_db):
        deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)

    assert deducted_credits == 3
    quota_used_after = _get_quota_used(engine=engine, pool_id=pool_id)
    assert quota_used_after == 5
|
||||
|
||||
|
||||
def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
|
||||
engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
|
||||
|
||||
with (
|
||||
patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
|
||||
@ -34,7 +55,14 @@ def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -
|
||||
):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
|
||||
|
||||
with engine.connect() as connection:
|
||||
quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
|
||||
assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
|
||||
|
||||
assert quota_used == 10
|
||||
|
||||
def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
    """With 1 credit left, requesting 3 deducts just that 1 (9 used -> 10 used)."""
    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
    fake_db = SimpleNamespace(engine=engine)

    with patch("services.credit_pool_service.db", fake_db):
        deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)

    assert deducted_credits == 1
    quota_used_after = _get_quota_used(engine=engine, pool_id=pool_id)
    assert quota_used_after == 10
|
||||
|
||||
Loading…
Reference in New Issue
Block a user