diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index ada7f41ec2..1954602571 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -34,35 +34,52 @@ class CreditPoolService: .first() ) + @classmethod + def check_credits_available( + cls, + tenant_id: str, + credits_required: int, + pool_type: str = "trial", + ) -> bool: + """check if credits are available without deducting""" + pool = cls.get_pool(tenant_id, pool_type) + if not pool: + return False + return pool.remaining_credits >= credits_required + @classmethod def check_and_deduct_credits( cls, tenant_id: str, credits_required: int, pool_type: str = "trial", - ): - """check and deduct credits""" + ) -> int: + """check and deduct credits, returns actual credits deducted""" pool = cls.get_pool(tenant_id, pool_type) if not pool: raise QuotaExceededError("Credit pool not found") - if pool.remaining_credits < credits_required: - raise QuotaExceededError( - f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}" - ) + if pool.remaining_credits <= 0: + raise QuotaExceededError("No credits remaining") + + # deduct all remaining credits if less than required + actual_credits = min(credits_required, pool.remaining_credits) + try: with Session(db.engine) as session: - update_values = {"quota_used": pool.quota_used + credits_required} - - where_conditions = [ - TenantCreditPool.pool_type == pool_type, - TenantCreditPool.tenant_id == tenant_id, - TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit, - ] - stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values) + stmt = ( + update(TenantCreditPool) + .where( + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.pool_type == pool_type, + ) + .values(quota_used=TenantCreditPool.quota_used + actual_credits) + ) session.execute(stmt) session.commit() except Exception: logger.exception("Failed to deduct credits for tenant %s", tenant_id) raise QuotaExceededError("Failed to deduct credits") + + return actual_credits