diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 36a78a48f7..1c7138c478 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -620,7 +620,7 @@ class ProviderManager: for quota in configuration.quotas: if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID): # Init trial provider records if not exists - if quota.quota_type not in provider_quota_to_provider_record_dict: + if quota.quota_type not in provider_quota_to_provider_record_dict: try: # FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic new_provider_record = Provider( @@ -957,7 +957,7 @@ class ProviderManager: is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1, restrict_models=provider_quota.restrict_models, ) - + elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None: quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 194ad43151..7b8bb6368e 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -136,9 +136,9 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs used_quota = 1 if used_quota is not None and system_configuration.current_quota_type is not None: - if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: from services.credit_pool_service import CreditPoolService + CreditPoolService.check_and_deduct_credits( tenant_id=tenant_id, credits_required=used_quota, diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 108ee05e45..8ae409809a 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from sqlalchemy import update from sqlalchemy.orm import Session @@ -24,7 +23,7 @@ class CreditPoolService: return credit_pool @classmethod - def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]: + def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None: """get tenant credit pool""" return ( db.session.query(TenantCreditPool) diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 74df593782..cbba537333 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -131,6 +131,7 @@ class FeatureModel(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) knowledge_pipeline: KnowledgePipeline = KnowledgePipeline() + next_credit_reset_date: int = 0 class KnowledgeRateLimitModel(BaseModel): @@ -282,6 +283,9 @@ class FeatureService: if "knowledge_pipeline_publish_enabled" in billing_info: features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"] + + if "next_credit_reset_date" in billing_info: + features.next_credit_reset_date = billing_info["next_credit_reset_date"] @classmethod def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index c71a19636d..cf2c6cea1c 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -31,7 +31,8 @@ class WorkspaceService: assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo + feature = FeatureService.get_features(tenant.id) + can_replace_logo = feature.can_replace_logo if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]): base_url = dify_config.FILES_URL @@ -47,6 +48,9 @@ class WorkspaceService: "replace_webapp_logo": replace_webapp_logo, } if dify_config.EDITION == "CLOUD": + + tenant_info["next_credit_reset_date"] = feature.next_credit_reset_date + from services.credit_pool_service import CreditPoolService paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid") diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 737202f8de..aec8efd880 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -619,8 +619,13 @@ class TestTenantService: mock_tenant_instance.name = "Test User's Workspace" mock_tenant_class.return_value = mock_tenant_instance - # Execute test - TenantService.create_owner_tenant_if_not_exist(mock_account) + # Mock the db import in CreditPoolService to avoid database connection + with patch("services.credit_pool_service.db") as mock_credit_pool_db: + mock_credit_pool_db.session.add = MagicMock() + mock_credit_pool_db.session.commit = MagicMock() + + # Execute test + TenantService.create_owner_tenant_if_not_exist(mock_account) # Verify tenant was created with correct parameters mock_db_dependencies["db"].session.add.assert_called()