dify/api/core/app/llm/quota.py
-LAN- 19476109da
chore(api): upgrade graphon to v0.3.0 (#35469)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
2026-05-09 07:30:03 +00:00

201 lines
7.9 KiB
Python

"""Tenant-scoped helpers for checking and deducting LLM provider quota.
System-hosted quota accounting is currently defined only for LLM models. Keep
the public helpers LLM-specific so callers do not carry unused model-type
plumbing, and fail loudly if the deprecated ``ModelInstance`` wrappers are used
with a non-LLM model.
"""
import warnings
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.model_runtime.entities.model_entities import ModelType
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def _get_provider_configuration(*, tenant_id: str, provider: str):
    """Look up the provider configuration bound to ``tenant_id`` for quota checks.

    Raises:
        ValueError: if the tenant has no configuration entry for ``provider``.
    """
    tenant_configurations = create_plugin_provider_manager(tenant_id=tenant_id).get_configurations(tenant_id)
    configuration = tenant_configurations.get(provider)
    if configuration is not None:
        return configuration
    raise ValueError(f"Provider {provider} does not exist.")
def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None:
    """Fail fast when a system-quota LLM model is already out of quota.

    Providers whose ``using_provider_type`` is not ``ProviderType.SYSTEM`` are not
    checked and return immediately.

    Raises:
        ValueError: if ``provider`` is not configured for the tenant.
        QuotaExceededError: if the LLM model's status is ``ModelStatus.QUOTA_EXCEEDED``.
    """
    configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
    if configuration.using_provider_type == ProviderType.SYSTEM:
        llm_model = configuration.get_provider_model(model_type=ModelType.LLM, model=model)
        # The model lookup may yield nothing; only a found model can be "exhausted".
        exhausted = bool(llm_model) and llm_model.status == ModelStatus.QUOTA_EXCEEDED
        if exhausted:
            raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None:
    """Translate one LLM invocation into a charge for the active quota bucket.

    The active bucket is the first entry of ``system_configuration.quota_configurations``
    whose ``quota_type`` equals ``system_configuration.current_quota_type``.

    Returns:
        ``None`` when no bucket matches, when the bucket is unlimited
        (``quota_limit == -1``), or when the bucket has no unit. Otherwise, the
        charge in the bucket's unit: ``usage.total_tokens`` for ``QuotaUnit.TOKENS``,
        the configured per-model credits for ``QuotaUnit.CREDITS``, and ``1`` for
        any other unit.
    """
    active_bucket = next(
        (
            bucket
            for bucket in system_configuration.quota_configurations
            if bucket.quota_type == system_configuration.current_quota_type
        ),
        None,
    )
    if active_bucket is None:
        return None
    if active_bucket.quota_limit == -1:
        # Unlimited buckets never consume quota.
        return None

    unit = active_bucket.quota_unit
    if not unit:
        return None
    if unit == QuotaUnit.TOKENS:
        return usage.total_tokens
    if unit == QuotaUnit.CREDITS:
        return dify_config.get_model_credits(model)
    return 1
def _deduct_free_llm_quota(
    *,
    tenant_id: str,
    provider: str,
    quota_type: ProviderQuotaType,
    used_quota: int,
) -> None:
    """Charge ``used_quota`` against the tenant's FREE system-provider quota row.

    The matching ``Provider`` row is loaded with ``SELECT ... FOR UPDATE`` inside a
    single ``sessionmaker(...).begin()`` transaction. The charge is capped at the
    remaining quota, and that capped amount is committed when the transaction block
    exits. Only afterwards is ``QuotaExceededError`` raised, if the full charge did
    not fit.

    A missing row, a row whose ``quota_limit`` or ``quota_used`` is ``None``, or a
    row already at or over its limit is left unmodified and also raises
    ``QuotaExceededError``.
    """
    with sessionmaker(bind=db.engine).begin() as session:
        locked_row_query = (
            select(Provider)
            .where(
                Provider.tenant_id == tenant_id,
                # TODO: Use provider name with prefix after the data migration.
                Provider.provider_name == ModelProviderID(provider).provider_name,
                Provider.provider_type == ProviderType.SYSTEM.value,
                Provider.quota_type == quota_type,
            )
            .with_for_update()
        )
        record = session.scalar(locked_row_query)

        if record is None or record.quota_limit is None or record.quota_used is None:
            overflowed = True
        elif record.quota_limit <= record.quota_used:
            overflowed = True
        else:
            headroom = record.quota_limit - record.quota_used
            charge = min(used_quota, headroom)
            record.quota_used += charge
            record.last_used = naive_utc_now()
            overflowed = charge < used_quota

    # Raised only after the transaction block exits, so the capped deduction is committed.
    if overflowed:
        raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None:
    """Route a resolved LLM quota charge to the currently selected quota bucket.

    No-op when the provider is not on ``ProviderType.SYSTEM``, when ``used_quota`` is
    ``None``, or when the system configuration has no current quota type.
    TRIAL and PAID charges are passed to ``CreditPoolService.deduct_credits_capped``
    (PAID with ``pool_type="paid"``). FREE charges update the provider row via
    ``_deduct_free_llm_quota``. Any other quota type is ignored.
    """
    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    active_quota_type = provider_configuration.system_configuration.current_quota_type
    if used_quota is None or active_quota_type is None:
        return

    if active_quota_type == ProviderQuotaType.FREE:
        _deduct_free_llm_quota(
            tenant_id=tenant_id,
            provider=provider,
            quota_type=active_quota_type,
            used_quota=used_quota,
        )
    elif active_quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
        # Imported lazily, as the original did; presumably to avoid an import cycle — confirm.
        from services.credit_pool_service import CreditPoolService

        if active_quota_type == ProviderQuotaType.PAID:
            CreditPoolService.deduct_credits_capped(
                tenant_id=tenant_id,
                credits_required=used_quota,
                pool_type="paid",
            )
        else:
            CreditPoolService.deduct_credits_capped(
                tenant_id=tenant_id,
                credits_required=used_quota,
            )
def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None:
    """Deduct tenant-bound quota for the resolved LLM model identity.

    Quota is only metered for providers running on ``ProviderType.SYSTEM``. Any
    other provider type returns before the charge is resolved, so no system quota
    configuration or per-model credit pricing is consulted for it.

    Args:
        tenant_id: Tenant whose provider configuration and quota are used.
        provider: Provider identifier as known to the tenant's provider manager.
        model: LLM model name, used for per-model credit pricing.
        usage: Usage reported by the invocation; ``total_tokens`` is charged for
            token-based quotas.

    Raises:
        ValueError: if ``provider`` is not configured for the tenant.
        QuotaExceededError: propagated from the FREE-quota deduction when the
            charge does not fit in the remaining quota.
    """
    provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider)
    # Same gate as ensure_llm_quota_available_for_model. Check it before resolving the
    # charge, so non-system providers never trigger quota/credit lookups.
    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
        return

    used_quota = _resolve_llm_used_quota(
        system_configuration=provider_configuration.system_configuration,
        model=model,
        usage=usage,
    )
    _deduct_used_llm_quota(
        tenant_id=tenant_id,
        provider=provider,
        provider_configuration=provider_configuration,
        used_quota=used_quota,
    )
def _require_llm_model_instance(model_instance: ModelInstance) -> None:
    """Guard the deprecated ``ModelInstance`` wrappers against non-LLM models.

    Raises:
        ValueError: if ``model_instance`` does not wrap a ``ModelType.LLM`` model.
    """
    actual_model_type = model_instance.model_type_instance.model_type
    if actual_model_type == ModelType.LLM:
        return
    raise ValueError("LLM quota helpers only support LLM model instances.")
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
    """Deprecated: use ``ensure_llm_quota_available_for_model`` instead.

    Emits a ``DeprecationWarning`` attributed to the caller and rejects non-LLM
    instances. It then checks quota using the tenant id from the instance's provider
    bundle configuration, plus the instance's provider and model name.
    """
    deprecation_message = (
        "ensure_llm_quota_available(model_instance=...) is deprecated; "
        "use ensure_llm_quota_available_for_model(...) instead."
    )
    # stacklevel=2 points the warning at the caller, not at this wrapper.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    _require_llm_model_instance(model_instance)

    bundle_configuration = model_instance.provider_model_bundle.configuration
    ensure_llm_quota_available_for_model(
        tenant_id=bundle_configuration.tenant_id,
        provider=model_instance.provider,
        model=model_instance.model_name,
    )
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
    """Deprecated: use ``deduct_llm_quota_for_model`` instead.

    Emits a ``DeprecationWarning`` attributed to the caller and rejects non-LLM
    instances. It then deducts quota for ``tenant_id`` using the instance's provider
    and model name.
    """
    deprecation_message = (
        "deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; "
        "use deduct_llm_quota_for_model(...) instead."
    )
    # stacklevel=2 points the warning at the caller, not at this wrapper.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    _require_llm_model_instance(model_instance)

    provider_name = model_instance.provider
    model_name = model_instance.model_name
    deduct_llm_quota_for_model(
        tenant_id=tenant_id,
        provider=provider_name,
        model=model_name,
        usage=usage,
    )