diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 63d2235358..182f1b767d 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL # TODO: Use provider name with prefix after the data migration. Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_type == system_configuration.current_quota_type, Provider.quota_limit > Provider.quota_used, ) .values( diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 30933239f6..b2a8e9c114 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -626,9 +626,8 @@ class ProviderManager: if provider_record.provider_type != ProviderType.SYSTEM: continue - provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( - provider_record - ) + if provider_record.quota_type is not None: + provider_quota_to_provider_record_dict[provider_record.quota_type] = provider_record for quota in configuration.quotas: if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID): @@ -641,7 +640,7 @@ class ProviderManager: # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, provider_type=ProviderType.SYSTEM, - quota_type=quota.quota_type, + quota_type=quota.quota_type, # type: ignore[arg-type] quota_limit=0, # type: ignore quota_used=0, is_valid=True, @@ -921,9 +920,8 @@ class ProviderManager: if provider_record.provider_type != ProviderType.SYSTEM: continue - quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( - provider_record - ) + if provider_record.quota_type is not None: + quota_type_to_provider_records_dict[provider_record.quota_type] = provider_record # type: ignore[index] quota_configurations = [] if dify_config.EDITION == "CLOUD": diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 1ddcc8f792..f68cdaadde 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -157,7 +157,7 @@ def handle(sender: Message, **kwargs): tenant_id=tenant_id, provider_name=ModelProviderID(model_config.provider).provider_name, provider_type=ProviderType.SYSTEM.value, - quota_type=provider_configuration.system_configuration.current_quota_type.value, + quota_type=provider_configuration.system_configuration.current_quota_type, ), values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), additional_filters=_ProviderUpdateAdditionalFilters( diff --git a/api/models/provider.py b/api/models/provider.py index afeee20b1e..bdcfb7aa0d 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -13,7 +13,7 @@ from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db -from .enums import CredentialSourceType, PaymentStatus +from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType from .types import EnumText, LongText, StringUUID @@ -29,24 +29,6 @@ class ProviderType(StrEnum): raise ValueError(f"No matching enum found for value '{value}'") -class ProviderQuotaType(StrEnum): - PAID = auto() - """hosted paid quota""" - - FREE = auto() - """third-party free quota""" - - TRIAL = auto() - """hosted trial quota""" - - @staticmethod - def value_of(value: str) -> ProviderQuotaType: - for member in ProviderQuotaType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - class Provider(TypeBase): """ Provider model representing the API providers and their configurations. @@ -77,7 +59,9 @@ class Provider(TypeBase): last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) - quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="") + quota_type: Mapped[ProviderQuotaType | None] = mapped_column( + EnumText(ProviderQuotaType, length=40), nullable=True, server_default=text("''"), default=None + ) quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None) quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0) diff --git a/api/models/types.py b/api/models/types.py index f8369dab9e..98084563be 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -144,8 +144,8 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]): return dialect.type_descriptor(VARCHAR(self._length)) def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None: - if value is None: - return value + if value is None or value == "": + return None # Type annotation guarantees value is str at this point return self._enum_class(value) diff --git a/api/tests/unit_tests/models/test_provider_models.py b/api/tests/unit_tests/models/test_provider_models.py index f628e54a4d..d7b597e5fb 100644 --- a/api/tests/unit_tests/models/test_provider_models.py +++ b/api/tests/unit_tests/models/test_provider_models.py @@ -202,7 +202,7 @@ class TestProviderModel: # Assert assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False - assert provider.quota_type == "" + assert provider.quota_type is None assert provider.quota_limit is None assert provider.quota_used == 0 assert provider.credential_id is None