diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 538c55d931..6415bd239d 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -128,6 +128,16 @@ class HostedGeminiConfig(BaseSettings): default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,", ) + HOSTED_GEMINI_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted gemini service", + default=False, + ) + + HOSTED_GEMINI_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,", + ) + class HostedXAIConfig(BaseSettings): """ @@ -159,6 +169,16 @@ class HostedXAIConfig(BaseSettings): default="grok-3,grok-3-mini,grok-3-mini-fast", ) + HOSTED_XAI_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted XAI service", + default=False, + ) + + HOSTED_XAI_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="grok-3,grok-3-mini,grok-3-mini-fast", + ) + class HostedDeepseekConfig(BaseSettings): """ @@ -190,6 +210,16 @@ class HostedDeepseekConfig(BaseSettings): default="deepseek-chat,deepseek-reasoner", ) + HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted Deepseek service", + default=False, + ) + + HOSTED_DEEPSEEK_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="deepseek-chat,deepseek-reasoner", + ) + class HostedAzureOpenAiConfig(BaseSettings): """ @@ -252,6 +282,15 @@ class HostedAnthropicConfig(BaseSettings): "claude-3-7-sonnet-20250219," "claude-3-haiku-20240307", ) + HOSTED_ANTHROPIC_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="claude-opus-4-20250514," + "claude-sonnet-4-20250514," + 
"claude-3-5-haiku-20241022," + "claude-3-opus-20240229," + "claude-3-7-sonnet-20250219," + "claude-3-haiku-20240307", + ) class HostedMinmaxConfig(BaseSettings): diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 7aafa4bc80..ed08ecf57b 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -169,6 +169,11 @@ class HostingConfiguration: trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) + if dify_config.HOSTED_GEMINI_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + quotas.append(paid_quota) + if len(quotas) > 0: credentials = { "google_api_key": dify_config.HOSTED_GEMINI_API_KEY, @@ -196,7 +201,8 @@ class HostingConfiguration: if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: - paid_quota = PaidHostingQuota() - quotas.append(paid_quota) + paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + quotas.append(paid_quota) if len(quotas) > 0: credentials = { @@ -223,6 +229,11 @@ class HostingConfiguration: trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) quotas.append(trial_quota) + if dify_config.HOSTED_XAI_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + quotas.append(paid_quota) + if len(quotas) > 0: credentials = { "api_key": dify_config.HOSTED_XAI_API_KEY, @@ -248,6 +259,11 @@ class HostingConfiguration: trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) quotas.append(trial_quota) + if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + 
quotas.append(paid_quota) + if len(quotas) > 0: credentials = { "api_key": dify_config.HOSTED_DEEPSEEK_API_KEY, diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 1ac02d9b6a..2772048d26 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -916,11 +916,17 @@ class ProviderManager: if dify_config.EDITION == "CLOUD": from services.credit_pool_service import CreditPoolService - pool = CreditPoolService.get_or_create_pool( + trial_pool = CreditPoolService.get_pool( tenant_id=tenant_id, + pool_type="trial", + ) + paid_pool = CreditPoolService.get_pool( + tenant_id=tenant_id, + pool_type="paid", ) else: - pool = None + trial_pool = None + paid_pool = None for provider_quota in provider_hosting_configuration.quotas: if provider_quota.quota_type not in quota_type_to_provider_records_dict: @@ -942,13 +948,23 @@ class ProviderManager: raise ValueError("quota_used is None") if provider_record.quota_limit is None: raise ValueError("quota_limit is None") - if provider_quota.quota_type == ProviderQuotaType.TRIAL and pool is not None: + if provider_quota.quota_type == ProviderQuotaType.TRIAL and trial_pool is not None: quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=pool.quota_used, - quota_limit=pool.quota_limit, - is_valid=pool.quota_limit > pool.quota_used or pool.quota_limit == -1, + quota_used=trial_pool.quota_used, + quota_limit=trial_pool.quota_limit, + is_valid=trial_pool.quota_limit > trial_pool.quota_used or trial_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) + + elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=paid_pool.quota_used, + 
quota_limit=paid_pool.quota_limit, + is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, restrict_models=provider_quota.restrict_models, ) diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index ad969cdad1..054fbe033d 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -23,7 +23,7 @@ from libs.datetime_utils import naive_utc_now from models.model import Conversation from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID - +from core.entities.provider_entities import ProviderQuotaType from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError @@ -136,21 +136,36 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs used_quota = 1 if used_quota is not None and system_configuration.current_quota_type is not None: - with Session(db.engine) as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. 
- Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) + + if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, ) - session.execute(stmt) - session.commit() + elif system_configuration.current_quota_type == ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + else: + with Session(db.engine) as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. 
+ Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=naive_utc_now(), + ) + ) + session.execute(stmt) + session.commit() diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 9efe9a79af..12e0961bcc 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -133,18 +133,24 @@ def handle(sender: Message, **kwargs): system_configuration=system_configuration, model_name=model_config.model, ) - logger.info("used_quota: %s", used_quota) if used_quota is not None: if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL: - logger.info("deduct credits") from services.credit_pool_service import CreditPoolService CreditPoolService.check_and_deduct_credits( tenant_id=tenant_id, credits_required=used_quota, + pool_type="trial", + ) + elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", ) else: - logger.info("update provider quota") quota_update = _ProviderUpdateOperation( filters=_ProviderUpdateFilters( tenant_id=tenant_id, diff --git a/api/migrations/versions/2025_09_25_1520-58a70d22fdbd_add_table_credit_pool.py b/api/migrations/versions/2025_09_25_1520-58a70d22fdbd_add_table_credit_pool.py index d298d885f4..b050008fc2 100644 --- a/api/migrations/versions/2025_09_25_1520-58a70d22fdbd_add_table_credit_pool.py +++ 
b/api/migrations/versions/2025_09_25_1520-58a70d22fdbd_add_table_credit_pool.py @@ -32,8 +32,8 @@ def upgrade(): with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False) batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False) - # Data migration: Move trial quota data from providers to tenant_credit_pools - migrate_trial_quota_data() + # Data migration: Move quota data from providers to tenant_credit_pools + migrate_quota_data() # ### end Alembic commands ### @@ -48,49 +48,57 @@ def downgrade(): # ### end Alembic commands ### -def migrate_trial_quota_data(): +def migrate_quota_data(): """ Migrate quota data from providers table to tenant_credit_pools table - for providers with quota_type='trial', provider_name='openai', provider_type='system' + for providers with quota_type='trial' or 'paid', provider_name='openai', provider_type='system' """ # Create connection bind = op.get_bind() - # Query providers that match the criteria - select_sql = sa.text(""" - SELECT tenant_id, quota_limit, quota_used - FROM providers - WHERE quota_type = 'trial' - AND provider_name = 'openai' - AND provider_type = 'system' - AND quota_limit IS NOT NULL - """) + # Define quota type mappings + quota_type_mappings = ['trial', 'paid'] - result = bind.execute(select_sql) - providers_data = result.fetchall() - - # Insert data into tenant_credit_pools - for provider_data in providers_data: - tenant_id, quota_limit, quota_used = provider_data - - # Check if credit pool already exists for this tenant - check_sql = sa.text(""" - SELECT COUNT(*) - FROM tenant_credit_pools - WHERE tenant_id = :tenant_id AND pool_type = 'trial' + for quota_type in quota_type_mappings: + # Query providers that match the criteria + select_sql = sa.text(""" + SELECT tenant_id, quota_limit, quota_used + FROM providers + WHERE quota_type = :quota_type + AND provider_name = 'openai' + AND 
provider_type = 'system' + AND quota_limit IS NOT NULL """) - existing_count = bind.execute(check_sql, {"tenant_id": tenant_id}).scalar() + result = bind.execute(select_sql, {"quota_type": quota_type}) + providers_data = result.fetchall() - if existing_count == 0: - # Insert new credit pool record - insert_sql = sa.text(""" - INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at) - VALUES (:tenant_id, 'trial', :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + # Insert data into tenant_credit_pools + for provider_data in providers_data: + tenant_id, quota_limit, quota_used = provider_data + + # Check if credit pool already exists for this tenant and pool type + check_sql = sa.text(""" + SELECT COUNT(*) + FROM tenant_credit_pools + WHERE tenant_id = :tenant_id AND pool_type = :pool_type """) - bind.execute(insert_sql, { + existing_count = bind.execute(check_sql, { "tenant_id": tenant_id, - "quota_limit": quota_limit or 0, - "quota_used": quota_used or 0 - }) + "pool_type": quota_type + }).scalar() + + if existing_count == 0: + # Insert new credit pool record + insert_sql = sa.text(""" + INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at) + VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + """) + + bind.execute(insert_sql, { + "tenant_id": tenant_id, + "pool_type": quota_type, + "quota_limit": quota_limit or 0, + "quota_used": quota_used or 0 + }) diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index f72686b9ff..fc2d875bd0 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -7,6 +7,7 @@ from configs import dify_config from core.errors.error import QuotaExceededError from extensions.ext_database import db from models import TenantCreditPool +from sqlalchemy.orm import Session logger = logging.getLogger(__name__) @@ -23,62 
+24,27 @@ class CreditPoolService: return credit_pool @classmethod - def get_pool(cls, tenant_id: str) -> Optional[TenantCreditPool]: + def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]: """get tenant credit pool""" return ( db.session.query(TenantCreditPool) .filter_by( tenant_id=tenant_id, + pool_type=pool_type, ) .first() ) - @classmethod - def get_or_create_pool(cls, tenant_id: str) -> TenantCreditPool: - """get or create credit pool""" - # First try to get existing pool - pool = cls.get_pool(tenant_id) - if pool: - return pool - - # Create new pool if not exists, handle race condition - try: - # Double-check in case another thread created it - pool = ( - db.session.query(TenantCreditPool) - .filter_by( - tenant_id=tenant_id, - ) - .first() - ) - if pool: - return pool - - # Create new pool - pool = TenantCreditPool( - tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" - ) - db.session.add(pool) - db.session.commit() - - except Exception: - # If creation fails (e.g., due to race condition), rollback and try to get existing one - db.session.rollback() - pool = cls.get_pool(tenant_id) - if not pool: - raise - - return pool - @classmethod def check_and_deduct_credits( cls, tenant_id: str, credits_required: int, + pool_type: str = "trial", ): """check and deduct credits""" - logger.info("check and deduct credits") - pool = cls.get_pool(tenant_id) + + pool = cls.get_pool(tenant_id, pool_type) if not pool: raise QuotaExceededError("Credit pool not found") @@ -86,24 +52,17 @@ class CreditPoolService: raise QuotaExceededError( f"Insufficient credits. 
Required: {credits_required}, Available: {pool.remaining_credits}" ) + try: + with Session(db.engine) as session: + update_values = {"quota_used": TenantCreditPool.quota_used + credits_required} - with db.session.begin(): - update_values = {"quota_used": pool.quota_used + credits_required} - - where_conditions = [ - TenantCreditPool.tenant_id == tenant_id, - TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit, - ] - stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values) - db.session.execute(stmt) - - @classmethod - def check_deduct_credits(cls, tenant_id: str, credits_required: int) -> bool: - """check and deduct credits""" - pool = cls.get_pool(tenant_id) - if not pool: - return False - - if pool.remaining_credits < credits_required: - return False - return True + where_conditions = [ + TenantCreditPool.pool_type == pool_type, + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit, + ] + stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values) + session.execute(stmt) + session.commit() + except Exception as e: + raise QuotaExceededError("Failed to deduct credits") from e diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index a21aac1984..c71a19636d 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -49,8 +49,14 @@ class WorkspaceService: if dify_config.EDITION == "CLOUD": from services.credit_pool_service import CreditPoolService - pool = CreditPoolService.get_or_create_pool(tenant_id=tenant.id) - tenant_info["trial_credits"] = pool.quota_limit - tenant_info["trial_credits_used"] = pool.quota_used + paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid") + if paid_pool: + tenant_info["trial_credits"] = paid_pool.quota_limit + tenant_info["trial_credits_used"] = paid_pool.quota_used + else: + trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, 
pool_type="trial") + if trial_pool: + tenant_info["trial_credits"] = trial_pool.quota_limit + tenant_info["trial_credits_used"] = trial_pool.quota_used return tenant_info