diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 4ad30014c7..538c55d931 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings): default="", ) + HOSTED_POOL_CREDITS: int = Field( + description="Pool credits for hosted service", + default=200, + ) + def get_model_credits(self, model_name: str) -> int: """ Get credit value for a specific model name. @@ -70,11 +75,6 @@ class HostedOpenAiConfig(BaseSettings): "text-davinci-003", ) - HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description="Quota limit for hosted OpenAI service usage", - default=200, - ) - HOSTED_OPENAI_PAID_ENABLED: bool = Field( description="Enable paid access to hosted OpenAI service", default=False, @@ -98,6 +98,99 @@ class HostedOpenAiConfig(BaseSettings): ) +class HostedGeminiConfig(BaseSettings): + """ + Configuration for fetching Gemini service + """ + + HOSTED_GEMINI_API_KEY: str | None = Field( + description="API key for hosted Gemini service", + default=None, + ) + + HOSTED_GEMINI_API_BASE: str | None = Field( + description="Base URL for hosted Gemini API", + default=None, + ) + + HOSTED_GEMINI_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted Gemini service", + default=None, + ) + + HOSTED_GEMINI_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted Gemini service", + default=False, + ) + + HOSTED_GEMINI_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,", + ) + + +class HostedXAIConfig(BaseSettings): + """ + Configuration for fetching XAI service + """ + + HOSTED_XAI_API_KEY: str | None = Field( + description="API key for hosted XAI service", + default=None, + ) + + HOSTED_XAI_API_BASE: str | None = Field( + description="Base URL for hosted XAI API", + default=None, + ) + + HOSTED_XAI_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted XAI service", + default=None, + ) + + HOSTED_XAI_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted XAI service", + default=False, + ) + + HOSTED_XAI_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="grok-3,grok-3-mini,grok-3-mini-fast", + ) + + +class HostedDeepseekConfig(BaseSettings): + """ + Configuration for fetching Deepseek service + """ + + HOSTED_DEEPSEEK_API_KEY: str | None = Field( + description="API key for hosted Deepseek service", + default=None, + ) + + HOSTED_DEEPSEEK_API_BASE: str | None = Field( + description="Base URL for hosted Deepseek API", + default=None, + ) + + HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted Deepseek service", + default=None, + ) + + HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted Deepseek service", + default=False, + ) + + HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="deepseek-chat,deepseek-reasoner", + ) + + class HostedAzureOpenAiConfig(BaseSettings): """ Configuration for hosted Azure OpenAI service @@ -144,16 +237,22 @@ class HostedAnthropicConfig(BaseSettings): default=False, ) - HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( - description="Quota limit for hosted Anthropic service usage", - default=600000, - ) - HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( description="Enable paid access to hosted Anthropic service", default=False, ) + HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="claude-opus-4-20250514," + "claude-opus-4-20250514," + "claude-sonnet-4-20250514," + "claude-3-5-haiku-20241022," + "claude-3-opus-20240229," + "claude-3-7-sonnet-20250219," + "claude-3-haiku-20240307", + ) + class HostedMinmaxConfig(BaseSettings): """ @@ -250,5 +349,8 @@ class HostedServiceConfig( HostedModerationConfig, # credit config HostedCreditConfig, + HostedGeminiConfig, + HostedXAIConfig, + HostedDeepseekConfig, ): pass diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 6bec70b5da..5242e6e04c 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -51,6 +51,8 @@ tenant_fields = { "in_trial": fields.Boolean, "trial_end_reason": fields.String, "custom_config": fields.Raw(attribute="custom_config"), + "trial_credits": fields.Integer, + "trial_credits_used": fields.Integer, } tenants_fields = { diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index af860a1070..7aafa4bc80 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -56,6 +56,9 @@ class HostingConfiguration: self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax() self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark() self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek() self.moderation_config = self.init_moderation_config() @@ -128,7 +131,7 @@ class HostingConfiguration: quotas: list[HostingQuota] = [] if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: - hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT + hosted_quota_limit = 0 trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS") trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) @@ -156,14 +159,39 @@ class HostingConfiguration: quota_unit=quota_unit, ) - @staticmethod - def init_anthropic() -> HostingProvider: - quota_unit = QuotaUnit.TOKENS + def init_gemini(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS + quotas: list[HostingQuota] = [] + + if dify_config.HOSTED_GEMINI_TRIAL_ENABLED: + hosted_quota_limit = 0 + trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) + quotas.append(trial_quota) + + if len(quotas) > 0: + credentials = { + "google_api_key": dify_config.HOSTED_GEMINI_API_KEY, + } + + if dify_config.HOSTED_GEMINI_API_BASE: + credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE + + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + + def init_anthropic(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS quotas: list[HostingQuota] = [] if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: - hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT - trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) + hosted_quota_limit = 0 + trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) quotas.append(trial_quota) if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: @@ -185,6 +213,56 @@ class HostingConfiguration: quota_unit=quota_unit, ) + def init_xai(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS + quotas: list[HostingQuota] = [] + + if dify_config.HOSTED_XAI_TRIAL_ENABLED: + hosted_quota_limit = 0 + trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) + quotas.append(trial_quota) + + if len(quotas) > 0: + credentials = { + "api_key": dify_config.HOSTED_XAI_API_KEY, + } + + if dify_config.HOSTED_XAI_API_BASE: + credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE + + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + + def init_deepseek(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS + quotas: list[HostingQuota] = [] + + if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED: + hosted_quota_limit = 0 + trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) + quotas.append(trial_quota) + + if len(quotas) > 0: + credentials = { + "api_key": dify_config.HOSTED_DEEPSEEK_API_KEY, + } + + if dify_config.HOSTED_DEEPSEEK_API_BASE: + credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE + + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + @staticmethod def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 499d39bd5d..1ac02d9b6a 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -629,7 +629,7 @@ class ProviderManager: provider_name=ModelProviderID(provider_name).provider_name, provider_type=ProviderType.SYSTEM.value, quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=quota.quota_limit, # type: ignore + quota_limit=0, # type: ignore quota_used=0, is_valid=True, ) @@ -912,6 +912,16 @@ class ProviderManager: provider_record ) quota_configurations = [] + + if dify_config.EDITION == "CLOUD": + from services.credit_pool_service import CreditPoolService + + pool = CreditPoolService.get_or_create_pool( + tenant_id=tenant_id, + ) + else: + pool = None + for provider_quota in provider_hosting_configuration.quotas: if provider_quota.quota_type not in quota_type_to_provider_records_dict: if provider_quota.quota_type == ProviderQuotaType.FREE: @@ -932,16 +942,26 @@ class ProviderManager: raise ValueError("quota_used is None") if provider_record.quota_limit is None: raise ValueError("quota_limit is None") + if provider_quota.quota_type == ProviderQuotaType.TRIAL and pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=pool.quota_used, + quota_limit=pool.quota_limit, + is_valid=pool.quota_limit > pool.quota_used or pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=provider_record.quota_used, - quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used - or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + else: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=provider_record.quota_used, + quota_limit=provider_record.quota_limit, + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) quota_configurations.append(quota_configuration) diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 27efa539dc..d787b23fac 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity -from core.entities.provider_entities import QuotaUnit, SystemConfiguration +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration from events.message_event import message_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback @@ -135,20 +135,28 @@ def handle(sender: Message, **kwargs): ) if used_quota is not None: - quota_update = _ProviderUpdateOperation( - filters=_ProviderUpdateFilters( + if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( tenant_id=tenant_id, - provider_name=ModelProviderID(model_config.provider).provider_name, - provider_type=ProviderType.SYSTEM.value, - quota_type=provider_configuration.system_configuration.current_quota_type.value, - ), - values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), - additional_filters=_ProviderUpdateAdditionalFilters( - quota_limit_check=True # Provider.quota_limit > Provider.quota_used - ), - description="quota_deduction_update", - ) - updates_to_perform.append(quota_update) + credits_required=used_quota, + ) + else: + quota_update = _ProviderUpdateOperation( + filters=_ProviderUpdateFilters( + tenant_id=tenant_id, + provider_name=ModelProviderID(model_config.provider).provider_name, + provider_type=ProviderType.SYSTEM.value, + quota_type=provider_configuration.system_configuration.current_quota_type.value, + ), + values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), + additional_filters=_ProviderUpdateAdditionalFilters( + quota_limit_check=True # Provider.quota_limit > Provider.quota_used + ), + description="quota_deduction_update", + ) + updates_to_perform.append(quota_update) # Execute all updates start_time = time_module.perf_counter() diff --git a/api/migrations/versions/2025_09_25_1520-58a70d22fdbd_add_table_credit_pool.py b/api/migrations/versions/2025_09_25_1520-58a70d22fdbd_add_table_credit_pool.py new file mode 100644 index 0000000000..d298d885f4 --- /dev/null +++ b/api/migrations/versions/2025_09_25_1520-58a70d22fdbd_add_table_credit_pool.py @@ -0,0 +1,96 @@ +"""add table credit pool + +Revision ID: 58a70d22fdbd +Revises: 68519ad5cd18 +Create Date: 2025-09-25 15:20:40.367078 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '58a70d22fdbd' +down_revision = '68519ad5cd18' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tenant_credit_pools', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('pool_type', sa.String(length=40), nullable=False), + sa.Column('quota_limit', sa.BigInteger(), nullable=False), + sa.Column('quota_used', sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey') + ) + with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: + batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False) + batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False) + # Data migration: Move trial quota data from providers to tenant_credit_pools + migrate_trial_quota_data() + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: + batch_op.drop_index('tenant_credit_pool_tenant_id_idx') + batch_op.drop_index('tenant_credit_pool_pool_type_idx') + + op.drop_table('tenant_credit_pools') + # ### end Alembic commands ### + + +def migrate_trial_quota_data(): + """ + Migrate quota data from providers table to tenant_credit_pools table + for providers with quota_type='trial', provider_name='openai', provider_type='system' + """ + # Create connection + bind = op.get_bind() + + # Query providers that match the criteria + select_sql = sa.text(""" + SELECT tenant_id, quota_limit, quota_used + FROM providers + WHERE quota_type = 'trial' + AND provider_name = 'openai' + AND provider_type = 'system' + AND quota_limit IS NOT NULL + """) + + result = bind.execute(select_sql) + providers_data = result.fetchall() + + # Insert data into tenant_credit_pools + for provider_data in providers_data: + tenant_id, quota_limit, quota_used = provider_data + + # Check if credit pool already exists for this tenant + check_sql = sa.text(""" + SELECT COUNT(*) + FROM tenant_credit_pools + WHERE tenant_id = :tenant_id AND pool_type = 'trial' + """) + + existing_count = bind.execute(check_sql, {"tenant_id": tenant_id}).scalar() + + if existing_count == 0: + # Insert new credit pool record + insert_sql = sa.text(""" + INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at) + VALUES (:tenant_id, 'trial', :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + """) + + bind.execute(insert_sql, { + "tenant_id": tenant_id, + "quota_limit": quota_limit or 0, + "quota_used": quota_used or 0 + }) diff --git a/api/models/__init__.py b/api/models/__init__.py index 779484283f..6cdb7529e3 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -53,6 +53,7 @@ from .model import ( Site, Tag, TagBinding, + TenantCreditPool, TraceAppConfig, UploadFile, ) @@ -159,6 +160,7 @@ __all__ = [ "Tenant", "TenantAccountJoin", "TenantAccountRole", + "TenantCreditPool", "TenantDefaultModel", "TenantPreferredModelProvider", "TenantStatus", diff --git a/api/models/model.py b/api/models/model.py index a8218c3a4e..30ec03de97 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] -from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text +from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -1944,3 +1944,29 @@ class TraceAppConfig(Base): "created_at": str(self.created_at) if self.created_at else None, "updated_at": str(self.updated_at) if self.updated_at else None, } + + +class TenantCreditPool(Base): + __tablename__ = "tenant_credit_pools" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"), + sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"), + sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"), + ) + + id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + quota_limit = mapped_column(BigInteger, nullable=False, default=0) + quota_used = mapped_column(BigInteger, nullable=False, default=0) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP")) + updated_at = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + @property + def remaining_credits(self) -> int: + return max(0, self.quota_limit - self.quota_used) + + def has_sufficient_credits(self, required_credits: int) -> bool: + return self.remaining_credits >= required_credits diff --git a/api/services/account_service.py b/api/services/account_service.py index 0e699d16da..21637a69e5 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -993,6 +993,11 @@ class TenantService: tenant.encrypt_public_key = generate_key_pair(tenant.id) db.session.commit() + + from services.credit_pool_service import CreditPoolService + + CreditPoolService.create_default_pool(tenant.id) + return tenant @staticmethod diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py new file mode 100644 index 0000000000..2bf9f4118f --- /dev/null +++ b/api/services/credit_pool_service.py @@ -0,0 +1,107 @@ +from typing import Optional + +from sqlalchemy import update + +from configs import dify_config +from core.errors.error import QuotaExceededError +from extensions.ext_database import db +from models import TenantCreditPool + + +class CreditPoolService: + @classmethod + def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: + """create default credit pool for new tenant""" + credit_pool = TenantCreditPool( + tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + ) + db.session.add(credit_pool) + db.session.commit() + return credit_pool + + @classmethod + def get_pool(cls, tenant_id: str) -> Optional[TenantCreditPool]: + """get tenant credit pool""" + return ( + db.session.query(TenantCreditPool) + .filter_by( + tenant_id=tenant_id, + ) + .first() + ) + + @classmethod + def get_or_create_pool(cls, tenant_id: str) -> TenantCreditPool: + """get or create credit pool""" + # First try to get existing pool + pool = cls.get_pool(tenant_id) + if pool: + return pool + + # Create new pool if not exists, handle race condition + try: + # Double-check in case another thread created it + pool = ( + db.session.query(TenantCreditPool) + .filter_by( + tenant_id=tenant_id, + ) + .first() + ) + if pool: + return pool + + # Create new pool + pool = TenantCreditPool( + tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial" + ) + db.session.add(pool) + db.session.commit() + + except Exception: + # If creation fails (e.g., due to race condition), rollback and try to get existing one + db.session.rollback() + pool = cls.get_pool(tenant_id) + if not pool: + raise + + return pool + + @classmethod + def check_and_deduct_credits( + cls, + tenant_id: str, + credits_required: int, + ) -> bool: + """check and deduct credits""" + pool = cls.get_pool(tenant_id) + if not pool: + raise QuotaExceededError("Credit pool not found") + + if pool.remaining_credits < credits_required: + raise QuotaExceededError( + f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}" + ) + + with db.session.begin(): + update_values = {"quota_used": pool.quota_used + credits_required} + + where_conditions = [ + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit, + ] + stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values) + db.session.execute(stmt) + + return True + + @classmethod + def check_deduct_credits(cls, tenant_id: str, credits_required: int) -> bool: + """check and deduct credits""" + pool = cls.get_pool(tenant_id) + if not pool: + return False + + if pool.remaining_credits < credits_required: + return False + return True diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 292ac6e008..a21aac1984 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -46,5 +46,11 @@ class WorkspaceService: "remove_webapp_brand": remove_webapp_brand, "replace_webapp_logo": replace_webapp_logo, } + if dify_config.EDITION == "CLOUD": + from services.credit_pool_service import CreditPoolService + + pool = CreditPoolService.get_or_create_pool(tenant_id=tenant.id) + tenant_info["trial_credits"] = pool.quota_limit + tenant_info["trial_credits_used"] = pool.quota_used return tenant_info