diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 489af29460..3ecd646862 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -618,9 +618,9 @@ class ProviderManager: ) for quota in configuration.quotas: - if quota.quota_type == ProviderQuotaType.TRIAL: + if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID): # Init trial provider records if not exists - if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: + if quota.quota_type not in provider_quota_to_provider_record_dict: try: # FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic new_provider_record = Provider( @@ -628,7 +628,7 @@ class ProviderManager: # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, provider_type=ProviderType.SYSTEM.value, - quota_type=ProviderQuotaType.TRIAL.value, + quota_type=quota.quota_type, quota_limit=0, # type: ignore quota_used=0, is_valid=True, @@ -642,7 +642,7 @@ class ProviderManager: Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value, + Provider.quota_type == quota.quota_type, ) existed_provider_record = db.session.scalar(stmt) if not existed_provider_record: @@ -652,7 +652,7 @@ class ProviderManager: existed_provider_record.is_valid = True db.session.commit() - provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) + provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) return provider_name_to_provider_records_dict @@ -918,11 +918,11 @@ class ProviderManager: trail_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type="trial", + pool_type=ProviderQuotaType.TRIAL.value, ) paid_pool = CreditPoolService.get_pool( tenant_id=tenant_id, - pool_type="paid", + pool_type=ProviderQuotaType.PAID.value, ) else: trail_pool = None