add paid credit

This commit is contained in:
Yansong Zhang 2025-09-26 12:49:26 +08:00
parent 2ff280c4bf
commit e974c696f7
8 changed files with 192 additions and 126 deletions

View File

@ -128,6 +128,16 @@ class HostedGeminiConfig(BaseSettings):
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,", default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
) )
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings): class HostedXAIConfig(BaseSettings):
""" """
@ -159,6 +169,16 @@ class HostedXAIConfig(BaseSettings):
default="grok-3,grok-3-mini,grok-3-mini-fast", default="grok-3,grok-3-mini,grok-3-mini-fast",
) )
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings): class HostedDeepseekConfig(BaseSettings):
""" """
@ -190,6 +210,16 @@ class HostedDeepseekConfig(BaseSettings):
default="deepseek-chat,deepseek-reasoner", default="deepseek-chat,deepseek-reasoner",
) )
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedAzureOpenAiConfig(BaseSettings): class HostedAzureOpenAiConfig(BaseSettings):
""" """
@ -252,6 +282,16 @@ class HostedAnthropicConfig(BaseSettings):
"claude-3-7-sonnet-20250219," "claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
) )
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedMinmaxConfig(BaseSettings): class HostedMinmaxConfig(BaseSettings):

View File

@ -169,6 +169,11 @@ class HostingConfiguration:
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota) quotas.append(trial_quota)
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0: if len(quotas) > 0:
credentials = { credentials = {
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY, "google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
@ -196,7 +201,8 @@ class HostingConfiguration:
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota() paid_quota = PaidHostingQuota()
quotas.append(paid_quota) paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
quotas.append(paid_quota,restrict_models=paid_models)
if len(quotas) > 0: if len(quotas) > 0:
credentials = { credentials = {
@ -223,6 +229,11 @@ class HostingConfiguration:
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota) quotas.append(trial_quota)
if dify_config.HOSTED_XAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0: if len(quotas) > 0:
credentials = { credentials = {
"api_key": dify_config.HOSTED_XAI_API_KEY, "api_key": dify_config.HOSTED_XAI_API_KEY,
@ -248,6 +259,11 @@ class HostingConfiguration:
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota) quotas.append(trial_quota)
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0: if len(quotas) > 0:
credentials = { credentials = {
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY, "api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,

View File

@ -898,11 +898,17 @@ class ProviderManager:
if dify_config.EDITION == "CLOUD": if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService from services.credit_pool_service import CreditPoolService
pool = CreditPoolService.get_or_create_pool( trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id, tenant_id=tenant_id,
pool_type="trial",
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type="paid",
) )
else: else:
pool = None trail_pool = None
paid_pool = None
for provider_quota in provider_hosting_configuration.quotas: for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict: if provider_quota.quota_type not in quota_type_to_provider_records_dict:
@ -924,13 +930,23 @@ class ProviderManager:
raise ValueError("quota_used is None") raise ValueError("quota_used is None")
if provider_record.quota_limit is None: if provider_record.quota_limit is None:
raise ValueError("quota_limit is None") raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and pool is not None: if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration( quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type, quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=pool.quota_used, quota_used=trail_pool.quota_used,
quota_limit=pool.quota_limit, quota_limit=trail_pool.quota_limit,
is_valid=pool.quota_limit > pool.quota_used or pool.quota_limit == -1, is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models, restrict_models=provider_quota.restrict_models,
) )

View File

@ -23,7 +23,7 @@ from libs.datetime_utils import naive_utc_now
from models.model import Conversation from models.model import Conversation
from models.provider import Provider, ProviderType from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID from models.provider_ids import ModelProviderID
from core.entities.provider_entities import ProviderQuotaType
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
@ -136,21 +136,36 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1 used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None: if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = ( if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
update(Provider) from services.credit_pool_service import CreditPoolService
.where( CreditPoolService.check_and_deduct_credits(
Provider.tenant_id == tenant_id, tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration. credits_required=used_quota,
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
) )
session.execute(stmt) elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
session.commit() from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@ -133,18 +133,24 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration, system_configuration=system_configuration,
model_name=model_config.model, model_name=model_config.model,
) )
logger.info("used_quota: %s", used_quota)
if used_quota is not None: if used_quota is not None:
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL: if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
logger.info("deduct credits")
from services.credit_pool_service import CreditPoolService from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits( CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id, tenant_id=tenant_id,
credits_required=used_quota, credits_required=used_quota,
pool_type="trial",
)
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
) )
else: else:
logger.info("update provider quota")
quota_update = _ProviderUpdateOperation( quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters( filters=_ProviderUpdateFilters(
tenant_id=tenant_id, tenant_id=tenant_id,

View File

@ -32,8 +32,8 @@ def upgrade():
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False) batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False) batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
# Data migration: Move trial quota data from providers to tenant_credit_pools # Data migration: Move quota data from providers to tenant_credit_pools
migrate_trial_quota_data() migrate_quota_data()
# ### end Alembic commands ### # ### end Alembic commands ###
@ -48,49 +48,57 @@ def downgrade():
# ### end Alembic commands ### # ### end Alembic commands ###
def migrate_trial_quota_data(): def migrate_quota_data():
""" """
Migrate quota data from providers table to tenant_credit_pools table Migrate quota data from providers table to tenant_credit_pools table
for providers with quota_type='trial', provider_name='openai', provider_type='system' for providers with quota_type='trial' or 'paid', provider_name='openai', provider_type='system'
""" """
# Create connection # Create connection
bind = op.get_bind() bind = op.get_bind()
# Query providers that match the criteria # Define quota type mappings
select_sql = sa.text(""" quota_type_mappings = ['trial', 'paid']
SELECT tenant_id, quota_limit, quota_used
FROM providers
WHERE quota_type = 'trial'
AND provider_name = 'openai'
AND provider_type = 'system'
AND quota_limit IS NOT NULL
""")
result = bind.execute(select_sql) for quota_type in quota_type_mappings:
providers_data = result.fetchall() # Query providers that match the criteria
select_sql = sa.text("""
# Insert data into tenant_credit_pools SELECT tenant_id, quota_limit, quota_used
for provider_data in providers_data: FROM providers
tenant_id, quota_limit, quota_used = provider_data WHERE quota_type = :quota_type
AND provider_name = 'openai'
# Check if credit pool already exists for this tenant AND provider_type = 'system'
check_sql = sa.text(""" AND quota_limit IS NOT NULL
SELECT COUNT(*)
FROM tenant_credit_pools
WHERE tenant_id = :tenant_id AND pool_type = 'trial'
""") """)
existing_count = bind.execute(check_sql, {"tenant_id": tenant_id}).scalar() result = bind.execute(select_sql, {"quota_type": quota_type})
providers_data = result.fetchall()
if existing_count == 0: # Insert data into tenant_credit_pools
# Insert new credit pool record for provider_data in providers_data:
insert_sql = sa.text(""" tenant_id, quota_limit, quota_used = provider_data
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
VALUES (:tenant_id, 'trial', :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) # Check if credit pool already exists for this tenant and pool type
check_sql = sa.text("""
SELECT COUNT(*)
FROM tenant_credit_pools
WHERE tenant_id = :tenant_id AND pool_type = :pool_type
""") """)
bind.execute(insert_sql, { existing_count = bind.execute(check_sql, {
"tenant_id": tenant_id, "tenant_id": tenant_id,
"quota_limit": quota_limit or 0, "pool_type": quota_type
"quota_used": quota_used or 0 }).scalar()
})
if existing_count == 0:
# Insert new credit pool record
insert_sql = sa.text("""
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""")
bind.execute(insert_sql, {
"tenant_id": tenant_id,
"pool_type": quota_type,
"quota_limit": quota_limit or 0,
"quota_used": quota_used or 0
})

View File

@ -7,6 +7,7 @@ from configs import dify_config
from core.errors.error import QuotaExceededError from core.errors.error import QuotaExceededError
from extensions.ext_database import db from extensions.ext_database import db
from models import TenantCreditPool from models import TenantCreditPool
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,62 +24,27 @@ class CreditPoolService:
return credit_pool return credit_pool
@classmethod @classmethod
def get_pool(cls, tenant_id: str) -> Optional[TenantCreditPool]: def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]:
"""get tenant credit pool""" """get tenant credit pool"""
return ( return (
db.session.query(TenantCreditPool) db.session.query(TenantCreditPool)
.filter_by( .filter_by(
tenant_id=tenant_id, tenant_id=tenant_id,
pool_type=pool_type,
) )
.first() .first()
) )
@classmethod
def get_or_create_pool(cls, tenant_id: str) -> TenantCreditPool:
"""get or create credit pool"""
# First try to get existing pool
pool = cls.get_pool(tenant_id)
if pool:
return pool
# Create new pool if not exists, handle race condition
try:
# Double-check in case another thread created it
pool = (
db.session.query(TenantCreditPool)
.filter_by(
tenant_id=tenant_id,
)
.first()
)
if pool:
return pool
# Create new pool
pool = TenantCreditPool(
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
)
db.session.add(pool)
db.session.commit()
except Exception:
# If creation fails (e.g., due to race condition), rollback and try to get existing one
db.session.rollback()
pool = cls.get_pool(tenant_id)
if not pool:
raise
return pool
@classmethod @classmethod
def check_and_deduct_credits( def check_and_deduct_credits(
cls, cls,
tenant_id: str, tenant_id: str,
credits_required: int, credits_required: int,
pool_type: str = "trial",
): ):
"""check and deduct credits""" """check and deduct credits"""
logger.info("check and deduct credits")
pool = cls.get_pool(tenant_id) pool = cls.get_pool(tenant_id, pool_type)
if not pool: if not pool:
raise QuotaExceededError("Credit pool not found") raise QuotaExceededError("Credit pool not found")
@ -86,24 +52,17 @@ class CreditPoolService:
raise QuotaExceededError( raise QuotaExceededError(
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}" f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
) )
try:
with Session(db.engine) as session:
update_values = {"quota_used": pool.quota_used + credits_required}
with db.session.begin(): where_conditions = [
update_values = {"quota_used": pool.quota_used + credits_required} TenantCreditPool.pool_type == pool_type,
TenantCreditPool.tenant_id == tenant_id,
where_conditions = [ TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
TenantCreditPool.tenant_id == tenant_id, ]
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit, stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
] session.execute(stmt)
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values) session.commit()
db.session.execute(stmt) except Exception:
raise QuotaExceededError("Failed to deduct credits")
@classmethod
def check_deduct_credits(cls, tenant_id: str, credits_required: int) -> bool:
"""check and deduct credits"""
pool = cls.get_pool(tenant_id)
if not pool:
return False
if pool.remaining_credits < credits_required:
return False
return True

View File

@ -49,8 +49,14 @@ class WorkspaceService:
if dify_config.EDITION == "CLOUD": if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService from services.credit_pool_service import CreditPoolService
pool = CreditPoolService.get_or_create_pool(tenant_id=tenant.id) paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
tenant_info["trial_credits"] = pool.quota_limit if paid_pool:
tenant_info["trial_credits_used"] = pool.quota_used tenant_info["trial_credits"] = paid_pool.quota_limit
tenant_info["trial_credits_used"] = paid_pool.quota_used
else:
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
if trial_pool:
tenant_info["trial_credits"] = trial_pool.quota_limit
tenant_info["trial_credits_used"] = trial_pool.quota_used
return tenant_info return tenant_info