diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 2adeb9f625..e82443af1f 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -600,9 +600,9 @@ class ProviderManager: ) for quota in configuration.quotas: - if quota.quota_type == ProviderQuotaType.TRIAL: + if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID): # Init trial provider records if not exists - if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: + if quota.quota_type not in provider_quota_to_provider_record_dict: try: # FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic new_provider_record = Provider( @@ -610,7 +610,7 @@ class ProviderManager: # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, provider_type=ProviderType.SYSTEM.value, - quota_type=ProviderQuotaType.TRIAL.value, + quota_type=quota.quota_type, quota_limit=0, # type: ignore quota_used=0, is_valid=True, @@ -624,7 +624,7 @@ class ProviderManager: Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value, + Provider.quota_type == quota.quota_type, ) existed_provider_record = db.session.scalar(stmt) if not existed_provider_record: @@ -634,7 +634,7 @@ class ProviderManager: existed_provider_record.is_valid = True db.session.commit() - provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) + provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) return provider_name_to_provider_records_dict @@ -900,11 +900,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type="trial", + pool_type=ProviderQuotaType.TRIAL.value, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type="paid", + pool_type=ProviderQuotaType.PAID.value, ) else: trail_pool = None