# Patch summary (free text ahead of the first `diff --git` header; patch tools skip it).
#
# * api/core/app/workflow/layers/llm_quota.py
#     Adds a TODO comment above the existing `node._run = quota_aborted_run` override.
# * api/events/event_handlers/update_provider_when_message_created.py
#     The TRIAL and PAID branches of `handle` now call the new module-level helper
#     `_deduct_credit_pool_quota_capped`, which calls `CreditPoolService.deduct_credits_capped`
#     and logs a warning when fewer credits than requested were deducted.
# * api/services/credit_pool_service.py
#     - `_get_locked_pool`: selects the tenant's pool row with `.limit(1).with_for_update()`.
#     - `check_and_deduct_credits`: returns 0 for a non-positive request; otherwise adds exactly
#       `credits_required` to `quota_used`, or raises `QuotaExceededError` before touching the row
#       (pool missing, nothing remaining, or remaining < required).
#     - `deduct_credits_capped` (new): adds `min(credits_required, pool.remaining_credits)` to
#       `quota_used` and returns that amount; returns 0 when the pool is missing or nothing is left.
#     - Both methods log any other exception and re-raise it as
#       `QuotaExceededError("Failed to deduct credits")`.
# * Tests: the integration and unit tests are updated for the no-partial-deduction behaviour, and new
#   tests cover `deduct_credits_capped` and the message-created handler.
#
# NOTE(review): this patch was recovered from a whitespace-collapsed copy. Indentation depth and the
#   position of blank context lines are reconstructed from the hunk headers and Python structure —
#   verify against the target tree before applying.
# NOTE(review): the last hunk of api/tests/unit_tests/services/test_credit_pool_service.py appears to
#   carry one more added blank line than its `+55,14` header allows (and would leave three blank
#   lines before `test_deduct_credits_capped_...`) — confirm against the original commit.
# NOTE(review): `deduct_credits_capped` still raises `QuotaExceededError` on unexpected errors, so
#   `_deduct_credit_pool_quota_capped` can raise in that case; presumably `handle` deals with it —
#   confirm against the caller.
diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py
index 269cc77158..2422eed5a7 100644
--- a/api/core/app/workflow/layers/llm_quota.py
+++ b/api/core/app/workflow/layers/llm_quota.py
@@ -132,6 +132,7 @@ class LLMQuotaLayer(GraphEngineLayer):
                 error_type=error_type,
             )
 
+        # TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override.
         node._run = quota_aborted_run  # type: ignore[method-assign]
         self._send_abort_command(reason=reason)
 
diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py
index 1d615f0f87..8dec5876a9 100644
--- a/api/events/event_handlers/update_provider_when_message_created.py
+++ b/api/events/event_handlers/update_provider_when_message_created.py
@@ -137,17 +137,13 @@ def handle(sender: Message, **kwargs):
             if used_quota is not None:
                 match provider_configuration.system_configuration.current_quota_type:
                     case ProviderQuotaType.TRIAL:
-                        from services.credit_pool_service import CreditPoolService
-
-                        CreditPoolService.check_and_deduct_credits(
+                        _deduct_credit_pool_quota_capped(
                             tenant_id=tenant_id,
                             credits_required=used_quota,
                             pool_type="trial",
                         )
                     case ProviderQuotaType.PAID:
-                        from services.credit_pool_service import CreditPoolService
-
-                        CreditPoolService.check_and_deduct_credits(
+                        _deduct_credit_pool_quota_capped(
                             tenant_id=tenant_id,
                             credits_required=used_quota,
                             pool_type="paid",
@@ -200,6 +196,26 @@ def handle(sender: Message, **kwargs):
         raise
 
 
+def _deduct_credit_pool_quota_capped(*, tenant_id: str, credits_required: int, pool_type: str) -> None:
+    """Apply post-generation credit accounting without failing message persistence on quota exhaustion."""
+    from services.credit_pool_service import CreditPoolService
+
+    deducted_credits = CreditPoolService.deduct_credits_capped(
+        tenant_id=tenant_id,
+        credits_required=credits_required,
+        pool_type=pool_type,
+    )
+    if deducted_credits < credits_required:
+        logger.warning(
+            "Credit pool exhausted during message-created accounting, "
+            "tenant_id=%s, pool_type=%s, credits_required=%s, credits_deducted=%s",
+            tenant_id,
+            pool_type,
+            credits_required,
+            deducted_credits,
+        )
+
+
 def _calculate_quota_usage(
     *, message: Message, system_configuration: SystemConfiguration, model_name: str
 ) -> int | None:
diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py
index 9d57cc0ac9..1f419d7a5b 100644
--- a/api/services/credit_pool_service.py
+++ b/api/services/credit_pool_service.py
@@ -1,7 +1,7 @@
 import logging
 
-from sqlalchemy import select, update
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy import select
+from sqlalchemy.orm import Session, sessionmaker
 
 from configs import dify_config
 from core.errors.error import QuotaExceededError
@@ -13,6 +13,18 @@ logger = logging.getLogger(__name__)
 
 
 class CreditPoolService:
+    @staticmethod
+    def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None:
+        return session.scalar(
+            select(TenantCreditPool)
+            .where(
+                TenantCreditPool.tenant_id == tenant_id,
+                TenantCreditPool.pool_type == pool_type,
+            )
+            .limit(1)
+            .with_for_update()
+        )
+
     @classmethod
     def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
         """create default credit pool for new tenant"""
@@ -59,35 +71,57 @@ class CreditPoolService:
         credits_required: int,
         pool_type: str = "trial",
     ) -> int:
-        """Deduct credits, depleting the pool before raising if the balance is insufficient."""
-
-        pool = cls.get_pool(tenant_id, pool_type)
-        if not pool:
-            raise QuotaExceededError("Credit pool not found")
-
-        if pool.remaining_credits <= 0:
-            raise QuotaExceededError("No credits remaining")
-
-        remaining_credits = pool.remaining_credits
-        actual_credits = min(credits_required, remaining_credits)
-        quota_exceeded = actual_credits < credits_required
+        """Deduct exactly the requested credits or raise without mutating the pool."""
+        if credits_required <= 0:
+            return 0
 
         try:
-            with sessionmaker(db.engine).begin() as session:
-                stmt = (
-                    update(TenantCreditPool)
-                    .where(
-                        TenantCreditPool.tenant_id == tenant_id,
-                        TenantCreditPool.pool_type == pool_type,
-                    )
-                    .values(quota_used=TenantCreditPool.quota_used + actual_credits)
-                )
-                session.execute(stmt)
+            with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+                pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
+                if not pool:
+                    raise QuotaExceededError("Credit pool not found")
+
+                remaining_credits = pool.remaining_credits
+                if remaining_credits <= 0:
+                    raise QuotaExceededError("No credits remaining")
+                if remaining_credits < credits_required:
+                    raise QuotaExceededError("Insufficient credits remaining")
+
+                pool.quota_used += credits_required
+        except QuotaExceededError:
+            raise
         except Exception:
             logger.exception("Failed to deduct credits for tenant %s", tenant_id)
             raise QuotaExceededError("Failed to deduct credits")
 
-        if quota_exceeded:
-            raise QuotaExceededError("Insufficient credits remaining")
+        return credits_required
 
-        return actual_credits
+    @classmethod
+    def deduct_credits_capped(
+        cls,
+        tenant_id: str,
+        credits_required: int,
+        pool_type: str = "trial",
+    ) -> int:
+        """Deduct up to the available balance and return the actual deducted credits."""
+        if credits_required <= 0:
+            return 0
+
+        try:
+            with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
+                pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type)
+                if not pool:
+                    logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type)
+                    return 0
+
+                deducted_credits = min(credits_required, pool.remaining_credits)
+                if deducted_credits <= 0:
+                    return 0
+
+                pool.quota_used += deducted_credits
+                return deducted_credits
+        except QuotaExceededError:
+            raise
+        except Exception:
+            logger.exception("Failed to deduct capped credits for tenant %s", tenant_id)
+            raise QuotaExceededError("Failed to deduct credits")
diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py
index 45e75f4dc1..07dc3a4e9e 100644
--- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py
@@ -90,11 +90,14 @@ class TestCreditPoolService:
         pool = CreditPoolService.get_pool(tenant_id=tenant_id)
         assert pool.quota_used == credits_required
 
-    def test_check_and_deduct_credits_depletes_and_raises_when_insufficient(self, db_session_with_containers: Session):
+    def test_check_and_deduct_credits_raises_without_deducting_when_insufficient(
+        self, db_session_with_containers: Session
+    ):
         tenant_id = self._create_tenant_id()
         pool = CreditPoolService.create_default_pool(tenant_id)
         remaining = 5
         pool.quota_used = pool.quota_limit - remaining
+        quota_used = pool.quota_used
         db_session_with_containers.commit()
 
         with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"):
@@ -102,4 +105,19 @@ class TestCreditPoolService:
 
         db_session_with_containers.expire_all()
         updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
-        assert updated_pool.quota_used == pool.quota_limit
+        assert updated_pool.quota_used == quota_used
+
+    def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session):
+        tenant_id = self._create_tenant_id()
+        pool = CreditPoolService.create_default_pool(tenant_id)
+        remaining = 5
+        pool.quota_used = pool.quota_limit - remaining
+        quota_limit = pool.quota_limit
+        db_session_with_containers.commit()
+
+        result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200)
+
+        assert result == remaining
+        db_session_with_containers.expire_all()
+        updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id)
+        assert updated_pool.quota_used == quota_limit
diff --git a/api/tests/unit_tests/events/test_update_provider_when_message_created.py b/api/tests/unit_tests/events/test_update_provider_when_message_created.py
new file mode 100644
index 0000000000..6bf697b747
--- /dev/null
+++ b/api/tests/unit_tests/events/test_update_provider_when_message_created.py
@@ -0,0 +1,68 @@
+from types import SimpleNamespace
+from unittest.mock import patch
+from uuid import uuid4
+
+from sqlalchemy import create_engine, select
+
+from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
+from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
+from events.event_handlers import update_provider_when_message_created
+from models import TenantCreditPool
+from models.provider import ProviderType
+
+
+def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None:
+    engine = create_engine("sqlite:///:memory:")
+    TenantCreditPool.__table__.create(engine)
+    tenant_id = str(uuid4())
+    pool_id = str(uuid4())
+    with engine.begin() as connection:
+        connection.execute(
+            TenantCreditPool.__table__.insert(),
+            {
+                "id": pool_id,
+                "tenant_id": tenant_id,
+                "pool_type": ProviderQuotaType.TRIAL,
+                "quota_limit": 10,
+                "quota_used": 9,
+            },
+        )
+
+    system_configuration = SimpleNamespace(
+        current_quota_type=ProviderQuotaType.TRIAL,
+        quota_configurations=[
+            SimpleNamespace(
+                quota_type=ProviderQuotaType.TRIAL,
+                quota_unit=QuotaUnit.TOKENS,
+                quota_limit=10,
+            )
+        ],
+    )
+    application_generate_entity = ChatAppGenerateEntity.model_construct(
+        app_config=SimpleNamespace(tenant_id=tenant_id),
+        model_conf=SimpleNamespace(
+            provider="openai",
+            model="gpt-4o",
+            provider_model_bundle=SimpleNamespace(
+                configuration=SimpleNamespace(
+                    using_provider_type=ProviderType.SYSTEM,
+                    system_configuration=system_configuration,
+                )
+            ),
+        ),
+    )
+    message = SimpleNamespace(message_tokens=2, answer_tokens=1)
+
+    with (
+        patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
+        patch.object(update_provider_when_message_created, "_execute_provider_updates"),
+    ):
+        update_provider_when_message_created.handle(
+            sender=message,
+            application_generate_entity=application_generate_entity,
+        )
+
+    with engine.connect() as connection:
+        quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
+
+    assert quota_used == 10
diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py
index 8d30f4056d..7689282e8b 100644
--- a/api/tests/unit_tests/services/test_credit_pool_service.py
+++ b/api/tests/unit_tests/services/test_credit_pool_service.py
@@ -4,6 +4,7 @@ from uuid import uuid4
 
 import pytest
 from sqlalchemy import create_engine, select
+from sqlalchemy.engine import Engine
 
 from core.errors.error import QuotaExceededError
 from models import TenantCreditPool
@@ -11,7 +12,7 @@ from models.enums import ProviderQuotaType
 from services.credit_pool_service import CreditPoolService
 
 
-def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -> None:
+def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]:
     engine = create_engine("sqlite:///:memory:")
     TenantCreditPool.__table__.create(engine)
     tenant_id = str(uuid4())
@@ -23,10 +24,30 @@ def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -
                 "id": pool_id,
                 "tenant_id": tenant_id,
                 "pool_type": ProviderQuotaType.TRIAL,
-                "quota_limit": 10,
-                "quota_used": 9,
+                "quota_limit": quota_limit,
+                "quota_used": quota_used,
             },
         )
+    return engine, tenant_id, pool_id
+
+
+def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None:
+    with engine.connect() as connection:
+        return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
+
+
+def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None:
+    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2)
+
+    with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
+        deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
+
+    assert deducted_credits == 3
+    assert _get_quota_used(engine=engine, pool_id=pool_id) == 5
+
+
+def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None:
+    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
 
     with (
         patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)),
@@ -34,7 +55,14 @@ def test_check_and_deduct_credits_depletes_pool_and_raises_when_insufficient() -
     ):
         CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3)
 
-    with engine.connect() as connection:
-        quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id))
+    assert _get_quota_used(engine=engine, pool_id=pool_id) == 9
 
-    assert quota_used == 10
+
+
+def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None:
+    engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9)
+
+    with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)):
+        deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3)
+
+    assert deducted_credits == 1
+    assert _get_quota_used(engine=engine, pool_id=pool_id) == 10