add paid credit

This commit is contained in:
Yansong Zhang 2025-09-26 12:49:26 +08:00
parent db0780cfa8
commit ab34cea714
8 changed files with 192 additions and 126 deletions

View File

@ -128,6 +128,16 @@ class HostedGeminiConfig(BaseSettings):
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings):
"""
@ -159,6 +169,16 @@ class HostedXAIConfig(BaseSettings):
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings):
"""
@ -190,6 +210,16 @@ class HostedDeepseekConfig(BaseSettings):
default="deepseek-chat,deepseek-reasoner",
)
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedAzureOpenAiConfig(BaseSettings):
"""
@ -252,6 +282,16 @@ class HostedAnthropicConfig(BaseSettings):
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedMinmaxConfig(BaseSettings):

View File

@ -169,6 +169,11 @@ class HostingConfiguration:
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
@ -196,7 +201,8 @@ class HostingConfiguration:
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota()
quotas.append(paid_quota)
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
quotas.append(paid_quota,restrict_models=paid_models)
if len(quotas) > 0:
credentials = {
@ -223,6 +229,11 @@ class HostingConfiguration:
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_XAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_XAI_API_KEY,
@ -248,6 +259,11 @@ class HostingConfiguration:
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,

View File

@ -916,11 +916,17 @@ class ProviderManager:
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
pool = CreditPoolService.get_or_create_pool(
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type="trial",
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type="paid",
)
else:
pool = None
trail_pool = None
paid_pool = None
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
@ -942,13 +948,23 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and pool is not None:
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=pool.quota_used,
quota_limit=pool.quota_limit,
is_valid=pool.quota_limit > pool.quota_used or pool.quota_limit == -1,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)

View File

@ -23,7 +23,7 @@ from libs.datetime_utils import naive_utc_now
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from core.entities.provider_entities import ProviderQuotaType
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
@ -136,21 +136,36 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
session.execute(stmt)
session.commit()
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@ -133,18 +133,24 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration,
model_name=model_config.model,
)
logger.info("used_quota: %s", used_quota)
if used_quota is not None:
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
logger.info("deduct credits")
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="trial",
)
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
logger.info("update provider quota")
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,

View File

@ -32,8 +32,8 @@ def upgrade():
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
# Data migration: Move trial quota data from providers to tenant_credit_pools
migrate_trial_quota_data()
# Data migration: Move quota data from providers to tenant_credit_pools
migrate_quota_data()
# ### end Alembic commands ###
@ -48,49 +48,57 @@ def downgrade():
# ### end Alembic commands ###
def migrate_trial_quota_data():
def migrate_quota_data():
"""
Migrate quota data from providers table to tenant_credit_pools table
for providers with quota_type='trial', provider_name='openai', provider_type='system'
for providers with quota_type='trial' or 'paid', provider_name='openai', provider_type='system'
"""
# Create connection
bind = op.get_bind()
# Query providers that match the criteria
select_sql = sa.text("""
SELECT tenant_id, quota_limit, quota_used
FROM providers
WHERE quota_type = 'trial'
AND provider_name = 'openai'
AND provider_type = 'system'
AND quota_limit IS NOT NULL
""")
# Define quota type mappings
quota_type_mappings = ['trial', 'paid']
result = bind.execute(select_sql)
providers_data = result.fetchall()
# Insert data into tenant_credit_pools
for provider_data in providers_data:
tenant_id, quota_limit, quota_used = provider_data
# Check if credit pool already exists for this tenant
check_sql = sa.text("""
SELECT COUNT(*)
FROM tenant_credit_pools
WHERE tenant_id = :tenant_id AND pool_type = 'trial'
for quota_type in quota_type_mappings:
# Query providers that match the criteria
select_sql = sa.text("""
SELECT tenant_id, quota_limit, quota_used
FROM providers
WHERE quota_type = :quota_type
AND provider_name = 'openai'
AND provider_type = 'system'
AND quota_limit IS NOT NULL
""")
existing_count = bind.execute(check_sql, {"tenant_id": tenant_id}).scalar()
result = bind.execute(select_sql, {"quota_type": quota_type})
providers_data = result.fetchall()
if existing_count == 0:
# Insert new credit pool record
insert_sql = sa.text("""
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
VALUES (:tenant_id, 'trial', :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
# Insert data into tenant_credit_pools
for provider_data in providers_data:
tenant_id, quota_limit, quota_used = provider_data
# Check if credit pool already exists for this tenant and pool type
check_sql = sa.text("""
SELECT COUNT(*)
FROM tenant_credit_pools
WHERE tenant_id = :tenant_id AND pool_type = :pool_type
""")
bind.execute(insert_sql, {
existing_count = bind.execute(check_sql, {
"tenant_id": tenant_id,
"quota_limit": quota_limit or 0,
"quota_used": quota_used or 0
})
"pool_type": quota_type
}).scalar()
if existing_count == 0:
# Insert new credit pool record
insert_sql = sa.text("""
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""")
bind.execute(insert_sql, {
"tenant_id": tenant_id,
"pool_type": quota_type,
"quota_limit": quota_limit or 0,
"quota_used": quota_used or 0
})

View File

@ -7,6 +7,7 @@ from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
@ -23,62 +24,27 @@ class CreditPoolService:
return credit_pool
@classmethod
def get_pool(cls, tenant_id: str) -> Optional[TenantCreditPool]:
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]:
"""get tenant credit pool"""
return (
db.session.query(TenantCreditPool)
.filter_by(
tenant_id=tenant_id,
pool_type=pool_type,
)
.first()
)
@classmethod
def get_or_create_pool(cls, tenant_id: str) -> TenantCreditPool:
"""get or create credit pool"""
# First try to get existing pool
pool = cls.get_pool(tenant_id)
if pool:
return pool
# Create new pool if not exists, handle race condition
try:
# Double-check in case another thread created it
pool = (
db.session.query(TenantCreditPool)
.filter_by(
tenant_id=tenant_id,
)
.first()
)
if pool:
return pool
# Create new pool
pool = TenantCreditPool(
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
)
db.session.add(pool)
db.session.commit()
except Exception:
# If creation fails (e.g., due to race condition), rollback and try to get existing one
db.session.rollback()
pool = cls.get_pool(tenant_id)
if not pool:
raise
return pool
@classmethod
def check_and_deduct_credits(
cls,
tenant_id: str,
credits_required: int,
pool_type: str = "trial",
):
"""check and deduct credits"""
logger.info("check and deduct credits")
pool = cls.get_pool(tenant_id)
pool = cls.get_pool(tenant_id, pool_type)
if not pool:
raise QuotaExceededError("Credit pool not found")
@ -86,24 +52,17 @@ class CreditPoolService:
raise QuotaExceededError(
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
)
try:
with Session(db.engine) as session:
update_values = {"quota_used": pool.quota_used + credits_required}
with db.session.begin():
update_values = {"quota_used": pool.quota_used + credits_required}
where_conditions = [
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
]
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
db.session.execute(stmt)
@classmethod
def check_deduct_credits(cls, tenant_id: str, credits_required: int) -> bool:
"""check and deduct credits"""
pool = cls.get_pool(tenant_id)
if not pool:
return False
if pool.remaining_credits < credits_required:
return False
return True
where_conditions = [
TenantCreditPool.pool_type == pool_type,
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
]
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
session.execute(stmt)
session.commit()
except Exception:
raise QuotaExceededError("Failed to deduct credits")

View File

@ -49,8 +49,14 @@ class WorkspaceService:
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
pool = CreditPoolService.get_or_create_pool(tenant_id=tenant.id)
tenant_info["trial_credits"] = pool.quota_limit
tenant_info["trial_credits_used"] = pool.quota_used
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
if paid_pool:
tenant_info["trial_credits"] = paid_pool.quota_limit
tenant_info["trial_credits_used"] = paid_pool.quota_used
else:
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
if trial_pool:
tenant_info["trial_credits"] = trial_pool.quota_limit
tenant_info["trial_credits_used"] = trial_pool.quota_used
return tenant_info