diff --git a/api/core/app/llm/__init__.py b/api/core/app/llm/__init__.py index d20a5b2344..85f342de5d 100644 --- a/api/core/app/llm/__init__.py +++ b/api/core/app/llm/__init__.py @@ -3,13 +3,17 @@ from .quota import ( deduct_llm_quota, deduct_llm_quota_for_model, + deduct_model_quota, ensure_llm_quota_available, ensure_llm_quota_available_for_model, + ensure_model_quota_available, ) __all__ = [ "deduct_llm_quota", "deduct_llm_quota_for_model", + "deduct_model_quota", "ensure_llm_quota_available", "ensure_llm_quota_available_for_model", + "ensure_model_quota_available", ] diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 723b090de9..b66749a467 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,8 +1,8 @@ -"""Tenant-scoped helpers for checking and deducting LLM provider quota. +"""Tenant-scoped helpers for checking and deducting provider model quota. -Workflow callers now bill quota from public model identity instead of passing a -fully prepared ``ModelInstance``. Keep the model-instance helpers as thin, -deprecated adapters so non-workflow code can move independently. +The public billing identity is ``tenant_id + provider + model_type + model``. +LLM callers still use thin adapters that compute quota usage from ``LLMUsage`` +so the workflow layer does not need to know generic billing details. """ import warnings @@ -33,28 +33,22 @@ def _get_provider_configuration(*, tenant_id: str, provider: str): return provider_configuration -def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None: +def ensure_model_quota_available(*, tenant_id: str, provider: str, model_type: ModelType, model: str) -> None: """Raise when a tenant-bound system provider model is already out of quota.""" provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider) if provider_configuration.using_provider_type != ProviderType.SYSTEM: return provider_model = provider_configuration.get_provider_model( - model_type=ModelType.LLM, + model_type=model_type, model=model, ) if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED: raise QuotaExceededError(f"Model provider {provider} quota exceeded.") -def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None: - """Deduct tenant-bound quota for the resolved LLM model identity.""" - provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider) - if provider_configuration.using_provider_type != ProviderType.SYSTEM: - return - - system_configuration = provider_configuration.system_configuration - +def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None: + """Compute the quota impact for an LLM invocation under the current quota mode.""" quota_unit = None for quota_configuration in system_configuration.quota_configurations: if quota_configuration.quota_type == system_configuration.current_quota_type: @@ -74,6 +68,21 @@ def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usa else: used_quota = 1 + return used_quota + + +def _deduct_model_quota_with_configuration( + *, + tenant_id: str, + provider: str, + provider_configuration, + used_quota: int | None, +) -> None: + """Apply a resolved quota charge against the current provider quota bucket.""" + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration if used_quota is not None and system_configuration.current_quota_type is not None: match system_configuration.current_quota_type: case ProviderQuotaType.TRIAL: @@ -111,6 +120,53 @@ def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usa session.execute(stmt) +def deduct_model_quota( + *, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + used_quota: int | None, +) -> None: + """Deduct quota for the resolved tenant/provider/model identity.""" + _ = model_type + _ = model + provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider) + _deduct_model_quota_with_configuration( + tenant_id=tenant_id, + provider=provider, + provider_configuration=provider_configuration, + used_quota=used_quota, + ) + + +def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None: + """Raise when a tenant-bound LLM model is already out of quota.""" + ensure_model_quota_available( + tenant_id=tenant_id, + provider=provider, + model_type=ModelType.LLM, + model=model, + ) + + +def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None: + """Deduct tenant-bound quota for the resolved LLM model identity.""" + provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider) + used_quota = _resolve_llm_used_quota( + system_configuration=provider_configuration.system_configuration, + model=model, + usage=usage, + ) + deduct_model_quota( + tenant_id=tenant_id, + provider=provider, + model_type=ModelType.LLM, + model=model, + used_quota=used_quota, + ) + + def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None: """Deprecated compatibility wrapper for callers that still pass ModelInstance.""" warnings.warn( @@ -119,9 +175,10 @@ def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None: DeprecationWarning, stacklevel=2, ) - ensure_llm_quota_available_for_model( + ensure_model_quota_available( tenant_id=model_instance.provider_model_bundle.configuration.tenant_id, provider=model_instance.provider, + model_type=model_instance.model_type_instance.model_type, model=model_instance.model_name, ) @@ -134,9 +191,14 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL DeprecationWarning, stacklevel=2, ) - deduct_llm_quota_for_model( + deduct_model_quota( tenant_id=tenant_id, provider=model_instance.provider, + model_type=model_instance.model_type_instance.model_type, model=model_instance.model_name, - usage=usage, + used_quota=_resolve_llm_used_quota( + system_configuration=model_instance.provider_model_bundle.configuration.system_configuration, + model=model_instance.model_name, + usage=usage, + ), ) diff --git a/api/tests/unit_tests/core/app/test_llm_quota.py b/api/tests/unit_tests/core/app/test_llm_quota.py index b94e3aa758..1ba8359b46 100644 --- a/api/tests/unit_tests/core/app/test_llm_quota.py +++ b/api/tests/unit_tests/core/app/test_llm_quota.py @@ -6,17 +6,20 @@ import pytest from core.app.llm.quota import ( deduct_llm_quota, deduct_llm_quota_for_model, + deduct_model_quota, ensure_llm_quota_available, ensure_llm_quota_available_for_model, + ensure_model_quota_available, ) from core.entities.model_entities import ModelStatus from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType -def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None: +def test_ensure_model_quota_available_raises_when_system_model_is_exhausted() -> None: provider_configuration = SimpleNamespace( using_provider_type=ProviderType.SYSTEM, get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)), @@ -28,17 +31,37 @@ def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhaus patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."), ): + ensure_model_quota_available( + tenant_id="tenant-id", + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="gpt-4o", + ) + + provider_configuration.get_provider_model.assert_called_once_with( + model_type=ModelType.TEXT_EMBEDDING, + model="gpt-4o", + ) + + +def test_ensure_llm_quota_available_for_model_delegates_with_llm_model_type() -> None: + with patch("core.app.llm.quota.ensure_model_quota_available") as mock_ensure: ensure_llm_quota_available_for_model( tenant_id="tenant-id", provider="openai", model="gpt-4o", ) + mock_ensure.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o", + ) -def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None: - usage = LLMUsage.empty_usage() - usage.total_tokens = 42 +def test_deduct_model_quota_uses_identity_based_trial_billing() -> None: + provider_configuration = SimpleNamespace( using_provider_type=ProviderType.SYSTEM, system_configuration=SimpleNamespace( @@ -59,11 +82,12 @@ def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None: patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits, ): - deduct_llm_quota_for_model( + deduct_model_quota( tenant_id="tenant-id", provider="openai", + model_type=ModelType.TEXT_EMBEDDING, model="gpt-4o", - usage=usage, + used_quota=42, ) mock_deduct_credits.assert_called_once_with( @@ -72,36 +96,90 @@ def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None: ) -def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None: +def test_deduct_llm_quota_for_model_delegates_with_llm_model_type_and_usage() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 42 + provider_configuration = SimpleNamespace( + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + + with ( + patch("core.app.llm.quota._get_provider_configuration", return_value=provider_configuration), + patch("core.app.llm.quota.deduct_model_quota") as mock_deduct, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_deduct.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o", + used_quota=42, + ) + + +def test_ensure_llm_quota_available_wrapper_warns_and_delegates_with_model_type() -> None: model_instance = SimpleNamespace( provider="openai", model_name="gpt-4o", provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")), + model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING), ) with ( pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"), - patch("core.app.llm.quota.ensure_llm_quota_available_for_model") as mock_ensure, + patch("core.app.llm.quota.ensure_model_quota_available") as mock_ensure, ): ensure_llm_quota_available(model_instance=model_instance) mock_ensure.assert_called_once_with( tenant_id="tenant-id", provider="openai", + model_type=ModelType.TEXT_EMBEDDING, model="gpt-4o", ) -def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None: +def test_deduct_llm_quota_wrapper_warns_and_delegates_with_model_type() -> None: usage = LLMUsage.empty_usage() + usage.total_tokens = 7 model_instance = SimpleNamespace( provider="openai", model_name="gpt-4o", + model_type_instance=SimpleNamespace(model_type=ModelType.LLM), + provider_model_bundle=SimpleNamespace( + configuration=SimpleNamespace( + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ) + ) + ), ) with ( pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"), - patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct, + patch("core.app.llm.quota.deduct_model_quota") as mock_deduct, ): deduct_llm_quota( tenant_id="tenant-id", @@ -112,6 +190,7 @@ def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", provider="openai", + model_type=ModelType.LLM, model="gpt-4o", - usage=usage, + used_quota=7, )