From fe0802262c533c285a20d1352bdafc52156e575c Mon Sep 17 00:00:00 2001 From: zyssyz123 <916125788@qq.com> Date: Thu, 8 Jan 2026 13:17:30 +0800 Subject: [PATCH] feat: credit pool (#30720) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../feature/hosted_service/__init__.py | 266 +++++++++++++++++- .../console/workspace/workspace.py | 3 + api/core/hosting_configuration.py | 137 ++++++++- api/core/provider_manager.py | 68 +++-- api/core/workflow/nodes/llm/llm_utils.py | 52 ++-- .../update_provider_when_message_created.py | 46 ++- ...12_25_1039-7df29de0f6be_add_credit_pool.py | 46 +++ api/models/__init__.py | 2 + api/models/model.py | 30 +- api/services/account_service.py | 5 + api/services/credit_pool_service.py | 85 ++++++ api/services/feature_service.py | 4 + api/services/workspace_service.py | 17 +- .../services/test_account_service.py | 9 +- 14 files changed, 694 insertions(+), 76 deletions(-) create mode 100644 api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py create mode 100644 api/services/credit_pool_service.py diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 4ad30014c7..42ede718c4 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings): default="", ) + HOSTED_POOL_CREDITS: int = Field( + description="Pool credits for hosted service", + default=200, + ) + def get_model_credits(self, model_name: str) -> int: """ Get credit value for a specific model name. 
@@ -60,19 +65,46 @@ class HostedOpenAiConfig(BaseSettings): HOSTED_OPENAI_TRIAL_MODELS: str = Field( description="Comma-separated list of available models for trial access", - default="gpt-3.5-turbo," - "gpt-3.5-turbo-1106," - "gpt-3.5-turbo-instruct," + default="gpt-4," + "gpt-4-turbo-preview," + "gpt-4-turbo-2024-04-09," + "gpt-4-1106-preview," + "gpt-4-0125-preview," + "gpt-4-turbo," + "gpt-4.1," + "gpt-4.1-2025-04-14," + "gpt-4.1-mini," + "gpt-4.1-mini-2025-04-14," + "gpt-4.1-nano," + "gpt-4.1-nano-2025-04-14," + "gpt-3.5-turbo," "gpt-3.5-turbo-16k," "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-1106," "gpt-3.5-turbo-0613," "gpt-3.5-turbo-0125," - "text-davinci-003", - ) - - HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description="Quota limit for hosted OpenAI service usage", - default=200, + "gpt-3.5-turbo-instruct," + "text-davinci-003," + "chatgpt-4o-latest," + "gpt-4o," + "gpt-4o-2024-05-13," + "gpt-4o-2024-08-06," + "gpt-4o-2024-11-20," + "gpt-4o-audio-preview," + "gpt-4o-audio-preview-2025-06-03," + "gpt-4o-mini," + "gpt-4o-mini-2024-07-18," + "o3-mini," + "o3-mini-2025-01-31," + "gpt-5-mini-2025-08-07," + "gpt-5-mini," + "o4-mini," + "o4-mini-2025-04-16," + "gpt-5-chat-latest," + "gpt-5," + "gpt-5-2025-08-07," + "gpt-5-nano," + "gpt-5-nano-2025-08-07", ) HOSTED_OPENAI_PAID_ENABLED: bool = Field( @@ -87,6 +119,13 @@ class HostedOpenAiConfig(BaseSettings): "gpt-4-turbo-2024-04-09," "gpt-4-1106-preview," "gpt-4-0125-preview," + "gpt-4-turbo," + "gpt-4.1," + "gpt-4.1-2025-04-14," + "gpt-4.1-mini," + "gpt-4.1-mini-2025-04-14," + "gpt-4.1-nano," + "gpt-4.1-nano-2025-04-14," "gpt-3.5-turbo," "gpt-3.5-turbo-16k," "gpt-3.5-turbo-16k-0613," @@ -94,7 +133,150 @@ class HostedOpenAiConfig(BaseSettings): "gpt-3.5-turbo-0613," "gpt-3.5-turbo-0125," "gpt-3.5-turbo-instruct," - "text-davinci-003", + "text-davinci-003," + "chatgpt-4o-latest," + "gpt-4o," + "gpt-4o-2024-05-13," + "gpt-4o-2024-08-06," + "gpt-4o-2024-11-20," + "gpt-4o-audio-preview," + 
"gpt-4o-audio-preview-2025-06-03," + "gpt-4o-mini," + "gpt-4o-mini-2024-07-18," + "o3-mini," + "o3-mini-2025-01-31," + "gpt-5-mini-2025-08-07," + "gpt-5-mini," + "o4-mini," + "o4-mini-2025-04-16," + "gpt-5-chat-latest," + "gpt-5," + "gpt-5-2025-08-07," + "gpt-5-nano," + "gpt-5-nano-2025-08-07", + ) + + +class HostedGeminiConfig(BaseSettings): + """ + Configuration for fetching Gemini service + """ + + HOSTED_GEMINI_API_KEY: str | None = Field( + description="API key for hosted Gemini service", + default=None, + ) + + HOSTED_GEMINI_API_BASE: str | None = Field( + description="Base URL for hosted Gemini API", + default=None, + ) + + HOSTED_GEMINI_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted Gemini service", + default=None, + ) + + HOSTED_GEMINI_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted Gemini service", + default=False, + ) + + HOSTED_GEMINI_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,", + ) + + HOSTED_GEMINI_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted gemini service", + default=False, + ) + + HOSTED_GEMINI_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,", + ) + + +class HostedXAIConfig(BaseSettings): + """ + Configuration for fetching XAI service + """ + + HOSTED_XAI_API_KEY: str | None = Field( + description="API key for hosted XAI service", + default=None, + ) + + HOSTED_XAI_API_BASE: str | None = Field( + description="Base URL for hosted XAI API", + default=None, + ) + + HOSTED_XAI_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted XAI service", + default=None, + ) + + HOSTED_XAI_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted XAI service", + default=False, + ) + + 
HOSTED_XAI_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="grok-3,grok-3-mini,grok-3-mini-fast", + ) + + HOSTED_XAI_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted XAI service", + default=False, + ) + + HOSTED_XAI_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="grok-3,grok-3-mini,grok-3-mini-fast", + ) + + +class HostedDeepseekConfig(BaseSettings): + """ + Configuration for fetching Deepseek service + """ + + HOSTED_DEEPSEEK_API_KEY: str | None = Field( + description="API key for hosted Deepseek service", + default=None, + ) + + HOSTED_DEEPSEEK_API_BASE: str | None = Field( + description="Base URL for hosted Deepseek API", + default=None, + ) + + HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field( + description="Organization ID for hosted Deepseek service", + default=None, + ) + + HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted Deepseek service", + default=False, + ) + + HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="deepseek-chat,deepseek-reasoner", + ) + + HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted Deepseek service", + default=False, + ) + + HOSTED_DEEPSEEK_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="deepseek-chat,deepseek-reasoner", ) @@ -144,16 +326,66 @@ class HostedAnthropicConfig(BaseSettings): default=False, ) - HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( - description="Quota limit for hosted Anthropic service usage", - default=600000, - ) - HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( description="Enable paid access to hosted Anthropic service", default=False, ) + HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models 
for paid access", + default="claude-opus-4-20250514," + "claude-sonnet-4-20250514," + "claude-3-5-haiku-20241022," + "claude-3-opus-20240229," + "claude-3-7-sonnet-20250219," + "claude-3-haiku-20240307", + ) + HOSTED_ANTHROPIC_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="claude-opus-4-20250514," + "claude-sonnet-4-20250514," + "claude-3-5-haiku-20241022," + "claude-3-opus-20240229," + "claude-3-7-sonnet-20250219," + "claude-3-haiku-20240307", + ) + + +class HostedTongyiConfig(BaseSettings): + """ + Configuration for hosted Tongyi service + """ + + HOSTED_TONGYI_API_KEY: str | None = Field( + description="API key for hosted Tongyi service", + default=None, + ) + + HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT: bool = Field( + description="Use international endpoint for hosted Tongyi service", + default=False, + ) + + HOSTED_TONGYI_TRIAL_ENABLED: bool = Field( + description="Enable trial access to hosted Tongyi service", + default=False, + ) + + HOSTED_TONGYI_PAID_ENABLED: bool = Field( + description="Enable paid access to hosted Anthropic service", + default=False, + ) + + HOSTED_TONGYI_TRIAL_MODELS: str = Field( + description="Comma-separated list of available models for trial access", + default="", + ) + + HOSTED_TONGYI_PAID_MODELS: str = Field( + description="Comma-separated list of available models for paid access", + default="", + ) + class HostedMinmaxConfig(BaseSettings): """ @@ -246,9 +478,13 @@ class HostedServiceConfig( HostedOpenAiConfig, HostedSparkConfig, HostedZhipuAIConfig, + HostedTongyiConfig, # moderation HostedModerationConfig, # credit config HostedCreditConfig, + HostedGeminiConfig, + HostedXAIConfig, + HostedDeepseekConfig, ): pass diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 909a5ce201..52e6f7d737 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ 
-80,6 +80,9 @@ tenant_fields = { "in_trial": fields.Boolean, "trial_end_reason": fields.String, "custom_config": fields.Raw(attribute="custom_config"), + "trial_credits": fields.Integer, + "trial_credits_used": fields.Integer, + "next_credit_reset_date": fields.Integer, } tenants_fields = { diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index af860a1070..370e64e385 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -56,6 +56,10 @@ class HostingConfiguration: self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax() self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark() self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek() + self.provider_map[f"{DEFAULT_PLUGIN_ID}/tongyi/tongyi"] = self.init_tongyi() self.moderation_config = self.init_moderation_config() @@ -128,7 +132,7 @@ class HostingConfiguration: quotas: list[HostingQuota] = [] if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: - hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT + hosted_quota_limit = 0 trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS") trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) @@ -156,18 +160,49 @@ class HostingConfiguration: quota_unit=quota_unit, ) - @staticmethod - def init_anthropic() -> HostingProvider: - quota_unit = QuotaUnit.TOKENS + def init_gemini(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS + quotas: list[HostingQuota] = [] + + if dify_config.HOSTED_GEMINI_TRIAL_ENABLED: + hosted_quota_limit = 0 + trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS") + trial_quota = 
TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) + quotas.append(trial_quota) + + if dify_config.HOSTED_GEMINI_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + quotas.append(paid_quota) + + if len(quotas) > 0: + credentials = { + "google_api_key": dify_config.HOSTED_GEMINI_API_KEY, + } + + if dify_config.HOSTED_GEMINI_API_BASE: + credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE + + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + + def init_anthropic(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS quotas: list[HostingQuota] = [] if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: - hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT - trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) + hosted_quota_limit = 0 + trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) quotas.append(trial_quota) if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: - paid_quota = PaidHostingQuota() + paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) quotas.append(paid_quota) if len(quotas) > 0: @@ -185,6 +220,94 @@ class HostingConfiguration: quota_unit=quota_unit, ) + def init_tongyi(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS + quotas: list[HostingQuota] = [] + + if dify_config.HOSTED_TONGYI_TRIAL_ENABLED: + hosted_quota_limit = 0 + trail_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) + quotas.append(trial_quota) + + if 
dify_config.HOSTED_TONGYI_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + quotas.append(paid_quota) + + if len(quotas) > 0: + credentials = { + "dashscope_api_key": dify_config.HOSTED_TONGYI_API_KEY, + "use_international_endpoint": dify_config.HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT, + } + + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + + def init_xai(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS + quotas: list[HostingQuota] = [] + + if dify_config.HOSTED_XAI_TRIAL_ENABLED: + hosted_quota_limit = 0 + trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) + quotas.append(trial_quota) + + if dify_config.HOSTED_XAI_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + quotas.append(paid_quota) + + if len(quotas) > 0: + credentials = { + "api_key": dify_config.HOSTED_XAI_API_KEY, + } + + if dify_config.HOSTED_XAI_API_BASE: + credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE + + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + + def init_deepseek(self) -> HostingProvider: + quota_unit = QuotaUnit.CREDITS + quotas: list[HostingQuota] = [] + + if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED: + hosted_quota_limit = 0 + trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models) + quotas.append(trial_quota) + + if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED: + 
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) + quotas.append(paid_quota) + + if len(quotas) > 0: + credentials = { + "api_key": dify_config.HOSTED_DEEPSEEK_API_KEY, + } + + if dify_config.HOSTED_DEEPSEEK_API_BASE: + credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE + + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) + + return HostingProvider( + enabled=False, + quota_unit=quota_unit, + ) + @staticmethod def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 10d86d1762..fdbfca4330 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -618,18 +618,18 @@ class ProviderManager: ) for quota in configuration.quotas: - if quota.quota_type == ProviderQuotaType.TRIAL: + if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID): # Init trial provider records if not exists - if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: + if quota.quota_type not in provider_quota_to_provider_record_dict: try: # FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic new_provider_record = Provider( tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. 
provider_name=ModelProviderID(provider_name).provider_name, - provider_type=ProviderType.SYSTEM, - quota_type=ProviderQuotaType.TRIAL, - quota_limit=quota.quota_limit, # type: ignore + provider_type=ProviderType.SYSTEM.value, + quota_type=quota.quota_type, + quota_limit=0, # type: ignore quota_used=0, is_valid=True, ) @@ -641,8 +641,8 @@ class ProviderManager: stmt = select(Provider).where( Provider.tenant_id == tenant_id, Provider.provider_name == ModelProviderID(provider_name).provider_name, - Provider.provider_type == ProviderType.SYSTEM, - Provider.quota_type == ProviderQuotaType.TRIAL, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == quota.quota_type, ) existed_provider_record = db.session.scalar(stmt) if not existed_provider_record: @@ -912,6 +912,22 @@ class ProviderManager: provider_record ) quota_configurations = [] + + if dify_config.EDITION == "CLOUD": + from services.credit_pool_service import CreditPoolService + + trail_pool = CreditPoolService.get_pool( + tenant_id=tenant_id, + pool_type=ProviderQuotaType.TRIAL.value, + ) + paid_pool = CreditPoolService.get_pool( + tenant_id=tenant_id, + pool_type=ProviderQuotaType.PAID.value, + ) + else: + trail_pool = None + paid_pool = None + for provider_quota in provider_hosting_configuration.quotas: if provider_quota.quota_type not in quota_type_to_provider_records_dict: if provider_quota.quota_type == ProviderQuotaType.FREE: @@ -932,16 +948,36 @@ class ProviderManager: raise ValueError("quota_used is None") if provider_record.quota_limit is None: raise ValueError("quota_limit is None") + if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=trail_pool.quota_used, + quota_limit=trail_pool.quota_limit, + is_valid=trail_pool.quota_limit > trail_pool.quota_used or 
trail_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=provider_record.quota_used, - quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used - or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=paid_pool.quota_used, + quota_limit=paid_pool.quota_limit, + is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) + + else: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=provider_record.quota_used, + quota_limit=provider_record.quota_limit, + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) quota_configurations.append(quota_configuration) diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 0c545469bc..01e25cbf5c 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.provider_entities import QuotaUnit +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.file.models import File from 
core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -136,21 +136,37 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs used_quota = 1 if used_quota is not None and system_configuration.current_quota_type is not None: - with Session(db.engine) as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) + if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, ) - session.execute(stmt) - session.commit() + elif system_configuration.current_quota_type == ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + else: + with Session(db.engine) as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. 
+ Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=naive_utc_now(), + ) + ) + session.execute(stmt) + session.commit() diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 84266ab0fa..1ddcc8f792 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity -from core.entities.provider_entities import QuotaUnit, SystemConfiguration +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration from events.message_event import message_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback @@ -134,22 +134,38 @@ def handle(sender: Message, **kwargs): system_configuration=system_configuration, model_name=model_config.model, ) - if used_quota is not None: - quota_update = _ProviderUpdateOperation( - filters=_ProviderUpdateFilters( + if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( tenant_id=tenant_id, - provider_name=ModelProviderID(model_config.provider).provider_name, - provider_type=ProviderType.SYSTEM, - quota_type=provider_configuration.system_configuration.current_quota_type.value, - ), - values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, 
last_used=current_time), - additional_filters=_ProviderUpdateAdditionalFilters( - quota_limit_check=True # Provider.quota_limit > Provider.quota_used - ), - description="quota_deduction_update", - ) - updates_to_perform.append(quota_update) + credits_required=used_quota, + pool_type="trial", + ) + elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + else: + quota_update = _ProviderUpdateOperation( + filters=_ProviderUpdateFilters( + tenant_id=tenant_id, + provider_name=ModelProviderID(model_config.provider).provider_name, + provider_type=ProviderType.SYSTEM.value, + quota_type=provider_configuration.system_configuration.current_quota_type.value, + ), + values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), + additional_filters=_ProviderUpdateAdditionalFilters( + quota_limit_check=True # Provider.quota_limit > Provider.quota_used + ), + description="quota_deduction_update", + ) + updates_to_perform.append(quota_update) # Execute all updates start_time = time_module.perf_counter() diff --git a/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py new file mode 100644 index 0000000000..e89fcee7e5 --- /dev/null +++ b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py @@ -0,0 +1,46 @@ +"""add credit pool + +Revision ID: 7df29de0f6be +Revises: 03ea244985ce +Create Date: 2025-12-25 10:39:15.139304 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '7df29de0f6be' +down_revision = '03ea244985ce' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tenant_credit_pools', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False), + sa.Column('quota_limit', sa.BigInteger(), nullable=False), + sa.Column('quota_used', sa.BigInteger(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey') + ) + with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: + batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False) + batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + + with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op: + batch_op.drop_index('tenant_credit_pool_tenant_id_idx') + batch_op.drop_index('tenant_credit_pool_pool_type_idx') + + op.drop_table('tenant_credit_pools') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 906bc3198e..e23de832dc 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -60,6 +60,7 @@ from .model import ( Site, Tag, TagBinding, + TenantCreditPool, TraceAppConfig, UploadFile, ) @@ -177,6 +178,7 @@ __all__ = [ "Tenant", "TenantAccountJoin", "TenantAccountRole", + "TenantCreditPool", "TenantDefaultModel", "TenantPreferredModelProvider", "TenantStatus", diff --git a/api/models/model.py b/api/models/model.py index 46df047237..c791ae15b0 100644 
--- a/api/models/model.py +++ b/api/models/model.py @@ -12,8 +12,8 @@ from uuid import uuid4 import sqlalchemy as sa from flask import request -from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text +from flask_login import UserMixin # type: ignore[import-untyped] +from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -2073,3 +2073,29 @@ class TraceAppConfig(TypeBase): "created_at": str(self.created_at) if self.created_at else None, "updated_at": str(self.updated_at) if self.updated_at else None, } + + +class TenantCreditPool(Base): + __tablename__ = "tenant_credit_pools" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"), + sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"), + sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"), + ) + + id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()")) + tenant_id = mapped_column(StringUUID, nullable=False) + pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial") + quota_limit = mapped_column(BigInteger, nullable=False, default=0) + quota_used = mapped_column(BigInteger, nullable=False, default=0) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP")) + updated_at = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + + @property + def remaining_credits(self) -> int: + return max(0, self.quota_limit - self.quota_used) + + def has_sufficient_credits(self, required_credits: int) -> bool: + return self.remaining_credits >= required_credits diff --git a/api/services/account_service.py b/api/services/account_service.py index 5a549dc318..d38c9d5a66 100644 --- 
a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -999,6 +999,11 @@ class TenantService:
 
         tenant.encrypt_public_key = generate_key_pair(tenant.id)
         db.session.commit()
+
+        from services.credit_pool_service import CreditPoolService
+
+        CreditPoolService.create_default_pool(tenant.id)
+
         return tenant
 
     @staticmethod
diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py
new file mode 100644
index 0000000000..1954602571
--- /dev/null
+++ b/api/services/credit_pool_service.py
@@ -0,0 +1,118 @@
+import logging
+
+from sqlalchemy import update
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from core.errors.error import QuotaExceededError
+from extensions.ext_database import db
+from models import TenantCreditPool
+
+logger = logging.getLogger(__name__)
+
+
+class CreditPoolService:
+    """Create, look up, and deduct from per-tenant credit pools."""
+
+    @classmethod
+    def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
+        """Create and persist the default "trial" credit pool for a new tenant.
+
+        The pool starts with ``dify_config.HOSTED_POOL_CREDITS`` as its limit and
+        zero usage. The shared ``db.session`` is committed before returning.
+        """
+        credit_pool = TenantCreditPool(
+            tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
+        )
+        db.session.add(credit_pool)
+        db.session.commit()
+        return credit_pool
+
+    @classmethod
+    def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
+        """Return the tenant's pool of the given type, or ``None`` if it has none."""
+        return (
+            db.session.query(TenantCreditPool)
+            .filter_by(
+                tenant_id=tenant_id,
+                pool_type=pool_type,
+            )
+            .first()
+        )
+
+    @classmethod
+    def check_credits_available(
+        cls,
+        tenant_id: str,
+        credits_required: int,
+        pool_type: str = "trial",
+    ) -> bool:
+        """Return whether the pool can cover ``credits_required`` without deducting.
+
+        A missing pool is reported as ``False``.
+        """
+        pool = cls.get_pool(tenant_id, pool_type)
+        if not pool:
+            return False
+        return pool.remaining_credits >= credits_required
+
+    @classmethod
+    def check_and_deduct_credits(
+        cls,
+        tenant_id: str,
+        credits_required: int,
+        pool_type: str = "trial",
+    ) -> int:
+        """Deduct up to ``credits_required`` credits from the tenant's pool.
+
+        The balance check and the deduction run in one transaction that reads
+        the pool row with ``SELECT ... FOR UPDATE``, so on databases that honour
+        row locks concurrent callers cannot both spend the same credits.
+
+        Returns:
+            The number of credits actually deducted; this is less than
+            ``credits_required`` when the pool holds fewer credits than asked.
+
+        Raises:
+            QuotaExceededError: if the pool does not exist, has no credits
+                remaining, or the deduction could not be written.
+        """
+        try:
+            with Session(db.engine) as session:
+                # Lock the row until commit so the balance read below cannot go
+                # stale before the UPDATE is applied.
+                pool = (
+                    session.query(TenantCreditPool)
+                    .filter_by(
+                        tenant_id=tenant_id,
+                        pool_type=pool_type,
+                    )
+                    .with_for_update()
+                    .first()
+                )
+                if not pool:
+                    raise QuotaExceededError("Credit pool not found")
+
+                if pool.remaining_credits <= 0:
+                    raise QuotaExceededError("No credits remaining")
+
+                # deduct all remaining credits if less than required
+                actual_credits = min(credits_required, pool.remaining_credits)
+
+                stmt = (
+                    update(TenantCreditPool)
+                    .where(
+                        TenantCreditPool.tenant_id == tenant_id,
+                        TenantCreditPool.pool_type == pool_type,
+                    )
+                    .values(quota_used=TenantCreditPool.quota_used + actual_credits)
+                )
+                session.execute(stmt)
+                session.commit()
+        except QuotaExceededError:
+            # Business-rule failures raised above must reach the caller unchanged.
+            raise
+        except Exception as e:
+            logger.exception("Failed to deduct credits for tenant %s", tenant_id)
+            raise QuotaExceededError("Failed to deduct credits") from e
+
+        return actual_credits
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index 8035adc734..9b853b8337 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -140,6 +140,7 @@ class FeatureModel(BaseModel):
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
     knowledge_pipeline: KnowledgePipeline = KnowledgePipeline()
+    next_credit_reset_date: int = 0
 
 
 class KnowledgeRateLimitModel(BaseModel):
@@ -301,6 +302,9 @@ class FeatureService:
         if "knowledge_pipeline_publish_enabled" in billing_info:
             features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"]
 
+        if "next_credit_reset_date" in billing_info:
+            features.next_credit_reset_date = billing_info["next_credit_reset_date"]
+
     @classmethod
     def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel):
         enterprise_info = EnterpriseService.get_info()
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 292ac6e008..3ee41c2e8d 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -31,7 +31,8 @@ class WorkspaceService:
         assert tenant_account_join is not None, "TenantAccountJoin not found"
         tenant_info["role"] = tenant_account_join.role
 
-        can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
+        feature = FeatureService.get_features(tenant.id)
+        can_replace_logo = feature.can_replace_logo
 
         if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
             base_url = dify_config.FILES_URL
@@ -46,5 +47,21 @@ class WorkspaceService:
                 "remove_webapp_brand": remove_webapp_brand,
                 "replace_webapp_logo": replace_webapp_logo,
             }
+        if dify_config.EDITION == "CLOUD":
+            tenant_info["next_credit_reset_date"] = feature.next_credit_reset_date
+
+            from services.credit_pool_service import CreditPoolService
+
+            # A "paid" pool, when present, takes precedence over the "trial" pool;
+            # either way the figures are reported under the trial_credits* keys.
+            paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
+            if paid_pool:
+                tenant_info["trial_credits"] = paid_pool.quota_limit
+                tenant_info["trial_credits_used"] = paid_pool.quota_used
+            else:
+                trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
+                if trial_pool:
+                    tenant_info["trial_credits"] = trial_pool.quota_limit
+                    tenant_info["trial_credits_used"] = trial_pool.quota_used
 
         return tenant_info
diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py
index 627a04bcd0..e35ba74c56 100644
--- a/api/tests/unit_tests/services/test_account_service.py
+++ b/api/tests/unit_tests/services/test_account_service.py
@@ -619,8 +619,13 @@ class TestTenantService:
             mock_tenant_instance.name = "Test User's Workspace"
             mock_tenant_class.return_value = mock_tenant_instance
 
-            # Execute test
-            TenantService.create_owner_tenant_if_not_exist(mock_account)
+            # Mock the db import in CreditPoolService to avoid database connection
+            with patch("services.credit_pool_service.db") as mock_credit_pool_db:
+                mock_credit_pool_db.session.add = MagicMock()
+                mock_credit_pool_db.session.commit = MagicMock()
+
+                # Execute test
+                TenantService.create_owner_tenant_if_not_exist(mock_account)
 
             # Verify tenant was created with correct parameters
             mock_db_dependencies["db"].session.add.assert_called()