mirror of https://github.com/langgenius/dify.git
add paid credit
This commit is contained in:
parent
db0780cfa8
commit
ab34cea714
|
|
@ -128,6 +128,16 @@ class HostedGeminiConfig(BaseSettings):
|
|||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted gemini service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_GEMINI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
|
||||
)
|
||||
|
||||
|
||||
class HostedXAIConfig(BaseSettings):
|
||||
"""
|
||||
|
|
@ -159,6 +169,16 @@ class HostedXAIConfig(BaseSettings):
|
|||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_XAI_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
|
||||
class HostedDeepseekConfig(BaseSettings):
|
||||
"""
|
||||
|
|
@ -190,6 +210,16 @@ class HostedDeepseekConfig(BaseSettings):
|
|||
default="deepseek-chat,deepseek-reasoner",
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
|
||||
description="Enable paid access to hosted XAI service",
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="grok-3,grok-3-mini,grok-3-mini-fast",
|
||||
)
|
||||
|
||||
|
||||
class HostedAzureOpenAiConfig(BaseSettings):
|
||||
"""
|
||||
|
|
@ -252,6 +282,16 @@ class HostedAnthropicConfig(BaseSettings):
|
|||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
|
||||
description="Comma-separated list of available models for paid access",
|
||||
default="claude-opus-4-20250514,"
|
||||
"claude-opus-4-20250514,"
|
||||
"claude-sonnet-4-20250514,"
|
||||
"claude-3-5-haiku-20241022,"
|
||||
"claude-3-opus-20240229,"
|
||||
"claude-3-7-sonnet-20250219,"
|
||||
"claude-3-haiku-20240307",
|
||||
)
|
||||
|
||||
|
||||
class HostedMinmaxConfig(BaseSettings):
|
||||
|
|
|
|||
|
|
@ -169,6 +169,11 @@ class HostingConfiguration:
|
|||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
|
||||
|
|
@ -196,7 +201,8 @@ class HostingConfiguration:
|
|||
|
||||
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
|
||||
paid_quota = PaidHostingQuota()
|
||||
quotas.append(paid_quota)
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
|
||||
quotas.append(paid_quota,restrict_models=paid_models)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
|
|
@ -223,6 +229,11 @@ class HostingConfiguration:
|
|||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_XAI_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_XAI_API_KEY,
|
||||
|
|
@ -248,6 +259,11 @@ class HostingConfiguration:
|
|||
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
|
||||
quotas.append(trial_quota)
|
||||
|
||||
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
|
||||
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
|
||||
paid_quota = PaidHostingQuota(restrict_models=paid_models)
|
||||
quotas.append(paid_quota)
|
||||
|
||||
if len(quotas) > 0:
|
||||
credentials = {
|
||||
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
|
||||
|
|
|
|||
|
|
@ -916,11 +916,17 @@ class ProviderManager:
|
|||
if dify_config.EDITION == "CLOUD":
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
pool = CreditPoolService.get_or_create_pool(
|
||||
trail_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type="trial",
|
||||
)
|
||||
paid_pool = CreditPoolService.get_pool(
|
||||
tenant_id=tenant_id,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
pool = None
|
||||
trail_pool = None
|
||||
paid_pool = None
|
||||
|
||||
for provider_quota in provider_hosting_configuration.quotas:
|
||||
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
|
||||
|
|
@ -942,13 +948,23 @@ class ProviderManager:
|
|||
raise ValueError("quota_used is None")
|
||||
if provider_record.quota_limit is None:
|
||||
raise ValueError("quota_limit is None")
|
||||
if provider_quota.quota_type == ProviderQuotaType.TRIAL and pool is not None:
|
||||
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=pool.quota_used,
|
||||
quota_limit=pool.quota_limit,
|
||||
is_valid=pool.quota_limit > pool.quota_used or pool.quota_limit == -1,
|
||||
quota_used=trail_pool.quota_used,
|
||||
quota_limit=trail_pool.quota_limit,
|
||||
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
|
||||
quota_configuration = QuotaConfiguration(
|
||||
quota_type=provider_quota.quota_type,
|
||||
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
|
||||
quota_used=paid_pool.quota_used,
|
||||
quota_limit=paid_pool.quota_limit,
|
||||
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
|
||||
restrict_models=provider_quota.restrict_models,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from libs.datetime_utils import naive_utc_now
|
|||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
from core.entities.provider_entities import ProviderQuotaType
|
||||
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
|
||||
|
||||
|
||||
|
|
@ -136,21 +136,36 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
|||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
|
|
|||
|
|
@ -133,18 +133,24 @@ def handle(sender: Message, **kwargs):
|
|||
system_configuration=system_configuration,
|
||||
model_name=model_config.model,
|
||||
)
|
||||
logger.info("used_quota: %s", used_quota)
|
||||
if used_quota is not None:
|
||||
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
logger.info("deduct credits")
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="trial",
|
||||
)
|
||||
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
logger.info("update provider quota")
|
||||
quota_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ def upgrade():
|
|||
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
|
||||
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
|
||||
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
# Data migration: Move trial quota data from providers to tenant_credit_pools
|
||||
migrate_trial_quota_data()
|
||||
# Data migration: Move quota data from providers to tenant_credit_pools
|
||||
migrate_quota_data()
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
|
@ -48,49 +48,57 @@ def downgrade():
|
|||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def migrate_trial_quota_data():
|
||||
def migrate_quota_data():
|
||||
"""
|
||||
Migrate quota data from providers table to tenant_credit_pools table
|
||||
for providers with quota_type='trial', provider_name='openai', provider_type='system'
|
||||
for providers with quota_type='trial' or 'paid', provider_name='openai', provider_type='system'
|
||||
"""
|
||||
# Create connection
|
||||
bind = op.get_bind()
|
||||
|
||||
# Query providers that match the criteria
|
||||
select_sql = sa.text("""
|
||||
SELECT tenant_id, quota_limit, quota_used
|
||||
FROM providers
|
||||
WHERE quota_type = 'trial'
|
||||
AND provider_name = 'openai'
|
||||
AND provider_type = 'system'
|
||||
AND quota_limit IS NOT NULL
|
||||
""")
|
||||
# Define quota type mappings
|
||||
quota_type_mappings = ['trial', 'paid']
|
||||
|
||||
result = bind.execute(select_sql)
|
||||
providers_data = result.fetchall()
|
||||
|
||||
# Insert data into tenant_credit_pools
|
||||
for provider_data in providers_data:
|
||||
tenant_id, quota_limit, quota_used = provider_data
|
||||
|
||||
# Check if credit pool already exists for this tenant
|
||||
check_sql = sa.text("""
|
||||
SELECT COUNT(*)
|
||||
FROM tenant_credit_pools
|
||||
WHERE tenant_id = :tenant_id AND pool_type = 'trial'
|
||||
for quota_type in quota_type_mappings:
|
||||
# Query providers that match the criteria
|
||||
select_sql = sa.text("""
|
||||
SELECT tenant_id, quota_limit, quota_used
|
||||
FROM providers
|
||||
WHERE quota_type = :quota_type
|
||||
AND provider_name = 'openai'
|
||||
AND provider_type = 'system'
|
||||
AND quota_limit IS NOT NULL
|
||||
""")
|
||||
|
||||
existing_count = bind.execute(check_sql, {"tenant_id": tenant_id}).scalar()
|
||||
result = bind.execute(select_sql, {"quota_type": quota_type})
|
||||
providers_data = result.fetchall()
|
||||
|
||||
if existing_count == 0:
|
||||
# Insert new credit pool record
|
||||
insert_sql = sa.text("""
|
||||
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
|
||||
VALUES (:tenant_id, 'trial', :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
# Insert data into tenant_credit_pools
|
||||
for provider_data in providers_data:
|
||||
tenant_id, quota_limit, quota_used = provider_data
|
||||
|
||||
# Check if credit pool already exists for this tenant and pool type
|
||||
check_sql = sa.text("""
|
||||
SELECT COUNT(*)
|
||||
FROM tenant_credit_pools
|
||||
WHERE tenant_id = :tenant_id AND pool_type = :pool_type
|
||||
""")
|
||||
|
||||
bind.execute(insert_sql, {
|
||||
existing_count = bind.execute(check_sql, {
|
||||
"tenant_id": tenant_id,
|
||||
"quota_limit": quota_limit or 0,
|
||||
"quota_used": quota_used or 0
|
||||
})
|
||||
"pool_type": quota_type
|
||||
}).scalar()
|
||||
|
||||
if existing_count == 0:
|
||||
# Insert new credit pool record
|
||||
insert_sql = sa.text("""
|
||||
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
|
||||
VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
""")
|
||||
|
||||
bind.execute(insert_sql, {
|
||||
"tenant_id": tenant_id,
|
||||
"pool_type": quota_type,
|
||||
"quota_limit": quota_limit or 0,
|
||||
"quota_used": quota_used or 0
|
||||
})
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from configs import dify_config
|
|||
from core.errors.error import QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from models import TenantCreditPool
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -23,62 +24,27 @@ class CreditPoolService:
|
|||
return credit_pool
|
||||
|
||||
@classmethod
|
||||
def get_pool(cls, tenant_id: str) -> Optional[TenantCreditPool]:
|
||||
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]:
|
||||
"""get tenant credit pool"""
|
||||
return (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_or_create_pool(cls, tenant_id: str) -> TenantCreditPool:
|
||||
"""get or create credit pool"""
|
||||
# First try to get existing pool
|
||||
pool = cls.get_pool(tenant_id)
|
||||
if pool:
|
||||
return pool
|
||||
|
||||
# Create new pool if not exists, handle race condition
|
||||
try:
|
||||
# Double-check in case another thread created it
|
||||
pool = (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if pool:
|
||||
return pool
|
||||
|
||||
# Create new pool
|
||||
pool = TenantCreditPool(
|
||||
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
|
||||
)
|
||||
db.session.add(pool)
|
||||
db.session.commit()
|
||||
|
||||
except Exception:
|
||||
# If creation fails (e.g., due to race condition), rollback and try to get existing one
|
||||
db.session.rollback()
|
||||
pool = cls.get_pool(tenant_id)
|
||||
if not pool:
|
||||
raise
|
||||
|
||||
return pool
|
||||
|
||||
@classmethod
|
||||
def check_and_deduct_credits(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
credits_required: int,
|
||||
pool_type: str = "trial",
|
||||
):
|
||||
"""check and deduct credits"""
|
||||
logger.info("check and deduct credits")
|
||||
pool = cls.get_pool(tenant_id)
|
||||
|
||||
pool = cls.get_pool(tenant_id, pool_type)
|
||||
if not pool:
|
||||
raise QuotaExceededError("Credit pool not found")
|
||||
|
||||
|
|
@ -86,24 +52,17 @@ class CreditPoolService:
|
|||
raise QuotaExceededError(
|
||||
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
update_values = {"quota_used": pool.quota_used + credits_required}
|
||||
|
||||
with db.session.begin():
|
||||
update_values = {"quota_used": pool.quota_used + credits_required}
|
||||
|
||||
where_conditions = [
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
|
||||
]
|
||||
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
|
||||
db.session.execute(stmt)
|
||||
|
||||
@classmethod
|
||||
def check_deduct_credits(cls, tenant_id: str, credits_required: int) -> bool:
|
||||
"""check and deduct credits"""
|
||||
pool = cls.get_pool(tenant_id)
|
||||
if not pool:
|
||||
return False
|
||||
|
||||
if pool.remaining_credits < credits_required:
|
||||
return False
|
||||
return True
|
||||
where_conditions = [
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
|
||||
]
|
||||
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
except Exception:
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
|
|
|||
|
|
@ -49,8 +49,14 @@ class WorkspaceService:
|
|||
if dify_config.EDITION == "CLOUD":
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
pool = CreditPoolService.get_or_create_pool(tenant_id=tenant.id)
|
||||
tenant_info["trial_credits"] = pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = pool.quota_used
|
||||
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
|
||||
if paid_pool:
|
||||
tenant_info["trial_credits"] = paid_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = paid_pool.quota_used
|
||||
else:
|
||||
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
|
||||
if trial_pool:
|
||||
tenant_info["trial_credits"] = trial_pool.quota_limit
|
||||
tenant_info["trial_credits_used"] = trial_pool.quota_used
|
||||
|
||||
return tenant_info
|
||||
|
|
|
|||
Loading…
Reference in New Issue