refactor(api): include model type in quota identity

Add model-type-aware quota helpers at the shared billing boundary while keeping the LLM-specific helpers as thin adapters.

Preserve model_type in the deprecated ModelInstance wrappers and extend the quota unit tests to cover the generic helper delegation path.
This commit is contained in:
-LAN- 2026-04-22 14:19:43 +08:00
parent 9b65e53c12
commit dc55adf9ae
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
3 changed files with 173 additions and 28 deletions

View File

@@ -3,13 +3,17 @@
from .quota import (
deduct_llm_quota,
deduct_llm_quota_for_model,
deduct_model_quota,
ensure_llm_quota_available,
ensure_llm_quota_available_for_model,
ensure_model_quota_available,
)
# Public billing API: the generic model-identity helpers plus the
# LLM-specific adapters kept for existing workflow callers.
__all__ = [
"deduct_llm_quota",
"deduct_llm_quota_for_model",
"deduct_model_quota",
"ensure_llm_quota_available",
"ensure_llm_quota_available_for_model",
"ensure_model_quota_available",
]

View File

@@ -1,8 +1,8 @@
"""Tenant-scoped helpers for checking and deducting LLM provider quota.
"""Tenant-scoped helpers for checking and deducting provider model quota.
Workflow callers now bill quota from public model identity instead of passing a
fully prepared ``ModelInstance``. Keep the model-instance helpers as thin,
deprecated adapters so non-workflow code can move independently.
The public billing identity is ``tenant_id + provider + model_type + model``.
LLM callers still use thin adapters that compute quota usage from ``LLMUsage``
so the workflow layer does not need to know generic billing details.
"""
import warnings
@@ -33,28 +33,22 @@ def _get_provider_configuration(*, tenant_id: str, provider: str):
return provider_configuration
def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
def ensure_model_quota_available(*, tenant_id: str, provider: str, model_type: ModelType, model: str) -> None:
    """Raise when a tenant-bound system provider model is already out of quota.

    The billing identity is ``tenant_id + provider + model_type + model``.
    Tenants using their own (non-system) provider credentials are never
    quota-limited here, so the check short-circuits for them.

    :raises QuotaExceededError: when the resolved system model reports
        ``ModelStatus.QUOTA_EXCEEDED``.
    """
    provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
    # Only system-managed providers are subject to platform quota.
    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return
    # Fix: the call previously carried a duplicated ``model_type`` keyword
    # (a stale hard-coded ``ModelType.LLM`` next to the parameter), which is
    # a SyntaxError; only the caller-supplied model_type is passed now.
    provider_model = provider_configuration.get_provider_model(
        model_type=model_type,
        model=model,
    )
    if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
        raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
"""Deduct tenant-bound quota for the resolved LLM model identity."""
provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None:
"""Compute the quota impact for an LLM invocation under the current quota mode."""
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
@@ -74,6 +68,21 @@ def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usa
else:
used_quota = 1
return used_quota
def _deduct_model_quota_with_configuration(
*,
tenant_id: str,
provider: str,
provider_configuration,
used_quota: int | None,
) -> None:
"""Apply a resolved quota charge against the current provider quota bucket."""
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
if used_quota is not None and system_configuration.current_quota_type is not None:
match system_configuration.current_quota_type:
case ProviderQuotaType.TRIAL:
@@ -111,6 +120,53 @@ def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usa
session.execute(stmt)
def deduct_model_quota(
    *,
    tenant_id: str,
    provider: str,
    model_type: ModelType,
    model: str,
    used_quota: int | None,
) -> None:
    """Charge ``used_quota`` against the tenant's bucket for ``provider``.

    ``model_type`` and ``model`` complete the public billing identity, but the
    current deduction path bills at provider granularity and does not consult
    them; they are accepted (and deliberately discarded) so callers can pass
    the full identity today without a later signature change.
    """
    # Intentionally unused for now — kept for the stable public signature.
    _ = model_type
    _ = model
    configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
    _deduct_model_quota_with_configuration(
        tenant_id=tenant_id,
        provider=provider,
        provider_configuration=configuration,
        used_quota=used_quota,
    )
def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
    """LLM-specific adapter over the generic quota check.

    Keeps existing LLM callers on a narrow signature while
    ``ensure_model_quota_available`` owns the full billing identity.
    """
    ensure_model_quota_available(tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model)
def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
    """LLM adapter: convert ``usage`` into a quota charge and bill it.

    Resolves the charge from the tenant's current quota mode, then delegates
    to the generic :func:`deduct_model_quota` with ``ModelType.LLM`` so the
    workflow layer never touches generic billing details directly.
    """
    configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
    charge = _resolve_llm_used_quota(
        system_configuration=configuration.system_configuration,
        model=model,
        usage=usage,
    )
    deduct_model_quota(
        tenant_id=tenant_id,
        provider=provider,
        model_type=ModelType.LLM,
        model=model,
        used_quota=charge,
    )
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
"""Deprecated compatibility wrapper for callers that still pass ModelInstance."""
warnings.warn(
@@ -119,9 +175,10 @@ def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
DeprecationWarning,
stacklevel=2,
)
ensure_llm_quota_available_for_model(
ensure_model_quota_available(
tenant_id=model_instance.provider_model_bundle.configuration.tenant_id,
provider=model_instance.provider,
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
@@ -134,9 +191,14 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
DeprecationWarning,
stacklevel=2,
)
deduct_llm_quota_for_model(
deduct_model_quota(
tenant_id=tenant_id,
provider=model_instance.provider,
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
usage=usage,
used_quota=_resolve_llm_used_quota(
system_configuration=model_instance.provider_model_bundle.configuration.system_configuration,
model=model_instance.model_name,
usage=usage,
),
)

View File

@@ -6,17 +6,20 @@ import pytest
from core.app.llm.quota import (
deduct_llm_quota,
deduct_llm_quota_for_model,
deduct_model_quota,
ensure_llm_quota_available,
ensure_llm_quota_available_for_model,
ensure_model_quota_available,
)
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from models.provider import ProviderType
def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None:
def test_ensure_model_quota_available_raises_when_system_model_is_exhausted() -> None:
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)),
@@ -28,17 +31,37 @@ def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhaus
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."),
):
ensure_model_quota_available(
tenant_id="tenant-id",
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="gpt-4o",
)
provider_configuration.get_provider_model.assert_called_once_with(
model_type=ModelType.TEXT_EMBEDDING,
model="gpt-4o",
)
def test_ensure_llm_quota_available_for_model_delegates_with_llm_model_type() -> None:
    """The LLM adapter must forward to the generic helper with ModelType.LLM."""
    with patch("core.app.llm.quota.ensure_model_quota_available") as mock_ensure:
        ensure_llm_quota_available_for_model(tenant_id="tenant-id", provider="openai", model="gpt-4o")
    mock_ensure.assert_called_once_with(
        tenant_id="tenant-id",
        provider="openai",
        model_type=ModelType.LLM,
        model="gpt-4o",
    )
def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 42
def test_deduct_model_quota_uses_identity_based_trial_billing() -> None:
provider_configuration = SimpleNamespace(
using_provider_type=ProviderType.SYSTEM,
system_configuration=SimpleNamespace(
@@ -59,11 +82,12 @@ def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None:
patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager),
patch("services.credit_pool_service.CreditPoolService.check_and_deduct_credits") as mock_deduct_credits,
):
deduct_llm_quota_for_model(
deduct_model_quota(
tenant_id="tenant-id",
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="gpt-4o",
usage=usage,
used_quota=42,
)
mock_deduct_credits.assert_called_once_with(
@@ -72,36 +96,90 @@
)
def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None:
def test_deduct_llm_quota_for_model_delegates_with_llm_model_type_and_usage() -> None:
    """Token usage is resolved to a charge before delegating to deduct_model_quota."""
    usage = LLMUsage.empty_usage()
    usage.total_tokens = 42
    # Trial mode billed in tokens, so the resolved charge equals total_tokens.
    trial_bucket = SimpleNamespace(
        quota_type=ProviderQuotaType.TRIAL,
        quota_unit=QuotaUnit.TOKENS,
        quota_limit=100,
    )
    provider_configuration = SimpleNamespace(
        system_configuration=SimpleNamespace(
            current_quota_type=ProviderQuotaType.TRIAL,
            quota_configurations=[trial_bucket],
        ),
    )
    with (
        patch("core.app.llm.quota._get_provider_configuration", return_value=provider_configuration),
        patch("core.app.llm.quota.deduct_model_quota") as mock_deduct,
    ):
        deduct_llm_quota_for_model(tenant_id="tenant-id", provider="openai", model="gpt-4o", usage=usage)
    mock_deduct.assert_called_once_with(
        tenant_id="tenant-id",
        provider="openai",
        model_type=ModelType.LLM,
        model="gpt-4o",
        used_quota=42,
    )
def test_ensure_llm_quota_available_wrapper_warns_and_delegates_with_model_type() -> None:
    """Deprecated ModelInstance wrapper warns and forwards the instance's model_type."""
    bundle = SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id"))
    type_instance = SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING)
    model_instance = SimpleNamespace(
        provider="openai",
        model_name="gpt-4o",
        provider_model_bundle=bundle,
        model_type_instance=type_instance,
    )
    with (
        pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"),
        patch("core.app.llm.quota.ensure_model_quota_available") as mock_ensure,
    ):
        ensure_llm_quota_available(model_instance=model_instance)
    mock_ensure.assert_called_once_with(
        tenant_id="tenant-id",
        provider="openai",
        model_type=ModelType.TEXT_EMBEDDING,
        model="gpt-4o",
    )
def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None:
def test_deduct_llm_quota_wrapper_warns_and_delegates_with_model_type() -> None:
usage = LLMUsage.empty_usage()
usage.total_tokens = 7
model_instance = SimpleNamespace(
provider="openai",
model_name="gpt-4o",
model_type_instance=SimpleNamespace(model_type=ModelType.LLM),
provider_model_bundle=SimpleNamespace(
configuration=SimpleNamespace(
system_configuration=SimpleNamespace(
current_quota_type=ProviderQuotaType.TRIAL,
quota_configurations=[
SimpleNamespace(
quota_type=ProviderQuotaType.TRIAL,
quota_unit=QuotaUnit.TOKENS,
quota_limit=100,
)
],
)
)
),
)
with (
pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"),
patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct,
patch("core.app.llm.quota.deduct_model_quota") as mock_deduct,
):
deduct_llm_quota(
tenant_id="tenant-id",
@@ -112,6 +190,7 @@ def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None:
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
provider="openai",
model_type=ModelType.LLM,
model="gpt-4o",
usage=usage,
used_quota=7,
)