diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 2e50077b46..16e5de3d4c 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -647,6 +647,13 @@ class BillingConfig(BaseSettings): default=False, ) + BILLING_FREE_PLAN_GRACE_PERIOD_DAYS: NonNegativeInt = Field( + description=( + "Extra grace period in days applied after a tenant leaves a paid plan before being treated as free." + ), + default=21, + ) + class UpdateConfig(BaseSettings): """ diff --git a/api/services/billing_service.py b/api/services/billing_service.py index cd7b5fc389..b449ada26f 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from typing import Literal import httpx +from pydantic import BaseModel, ValidationError from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from werkzeug.exceptions import InternalServerError @@ -16,6 +17,11 @@ from models import Account, TenantAccountJoin, TenantAccountRole logger = logging.getLogger(__name__) +class TenantPlanInfo(BaseModel): + plan: CloudPlan + expiration_date: int + + class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @@ -30,32 +36,46 @@ class BillingService: return billing_info @classmethod - def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, str]: + def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, TenantPlanInfo]: """ Bulk billing info fetch via billing API. Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request) Returns: - Mapping of tenant_id -> plan + Mapping of tenant_id -> TenantPlanInfo(plan + expiration timestamp) """ - results: dict[str, str] = {} + results: dict[str, TenantPlanInfo] = {} chunk_size = 200 for i in range(0, len(tenant_ids), chunk_size): chunk = tenant_ids[i : i + chunk_size] try: resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk}) - data = resp.get("data", {}) - for tenant_id, plan in data.items(): - if isinstance(plan, str): - results[tenant_id] = plan + results.update(cls._parse_bulk_response(chunk, resp)) except Exception: logger.exception("Failed to fetch billing info batch for tenants: %s", chunk) - continue + raise return results + @classmethod + def _parse_bulk_response(cls, expected_ids: Sequence[str], response: dict) -> dict[str, TenantPlanInfo]: + data = response.get("data") + if not isinstance(data, dict): + raise ValueError("Billing API response missing 'data' object.") + + parsed: dict[str, TenantPlanInfo] = {} + for tenant_id in expected_ids: + payload = data.get(tenant_id) + + try: + parsed[tenant_id] = TenantPlanInfo.model_validate(payload) + except ValidationError as exc: + raise ValueError(f"Invalid billing info for tenant {tenant_id}") from exc + + return parsed + @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): params = {"tenant_id": tenant_id} diff --git a/api/services/clear_free_plan_expired_workflow_run_logs.py b/api/services/clear_free_plan_expired_workflow_run_logs.py index 55603795b8..c3fbd6600a 100644 --- a/api/services/clear_free_plan_expired_workflow_run_logs.py +++ b/api/services/clear_free_plan_expired_workflow_run_logs.py @@ -14,7 +14,7 @@ from repositories.sqlalchemy_api_workflow_node_execution_repository import ( DifyAPISQLAlchemyWorkflowNodeExecutionRepository, ) from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository -from services.billing_service import BillingService +from services.billing_service import BillingService, TenantPlanInfo logger = logging.getLogger(__name__) @@ -40,8 +40,9 @@ class WorkflowRunCleanup: raise ValueError("end_before must be greater than start_after.") self.batch_size = batch_size - self.billing_cache: dict[str, CloudPlan | None] = {} + self.billing_cache: dict[str, TenantPlanInfo | None] = {} self.dry_run = dry_run + self.free_plan_grace_period_days = dify_config.BILLING_FREE_PLAN_GRACE_PERIOD_DAYS self.workflow_run_repo: APIWorkflowRunRepository if workflow_run_repo: self.workflow_run_repo = workflow_run_repo @@ -157,10 +158,14 @@ class WorkflowRunCleanup: click.echo(click.style(summary_message, fg=summary_color)) def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]: - if not dify_config.BILLING_ENABLED: - return set(tenant_ids) - tenant_id_list = list(tenant_ids) + + if not dify_config.BILLING_ENABLED: + return set(tenant_id_list) + + if not tenant_id_list: + return set() + uncached_tenants = [tenant_id for tenant_id in tenant_id_list if tenant_id not in self.billing_cache] if uncached_tenants: @@ -171,19 +176,47 @@ class WorkflowRunCleanup: logger.exception("Failed to fetch billing plans in bulk for tenants: %s", uncached_tenants) for tenant_id in uncached_tenants: - plan: CloudPlan | None = None info = bulk_info.get(tenant_id) - if info: - try: - plan = CloudPlan(info) - except Exception: - logger.exception("Failed to parse billing plan for tenant %s", tenant_id) - else: + if info is None: logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id) + self.billing_cache[tenant_id] = info - self.billing_cache[tenant_id] = plan + eligible_free_tenants: set[str] = set() + for tenant_id in tenant_id_list: + info = self.billing_cache.get(tenant_id) + if not info: + continue - return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX} + if info.plan != CloudPlan.SANDBOX: + continue + + if self._is_within_grace_period(tenant_id, info): + continue + + eligible_free_tenants.add(tenant_id) + + return eligible_free_tenants + + def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None: + if expiration_value < 0: + return None + + try: + return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC) + except (OverflowError, OSError, ValueError): + logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id) + return None + + def _is_within_grace_period(self, tenant_id: str, info: TenantPlanInfo) -> bool: + if self.free_plan_grace_period_days <= 0: + return False + + expiration_at = self._expiration_datetime(tenant_id, info.expiration_date) + if expiration_at is None: + return False + + grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days) + return datetime.datetime.now(datetime.UTC) < grace_deadline def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py index 24a07e7937..66cd7ff8c9 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -4,6 +4,7 @@ from typing import Any import pytest from services import clear_free_plan_expired_workflow_run_logs as cleanup_module +from services.billing_service import TenantPlanInfo from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup @@ -62,7 +63,18 @@ class FakeRepo: return result -def create_cleanup(monkeypatch: pytest.MonkeyPatch, repo: FakeRepo, **kwargs: Any) -> WorkflowRunCleanup: +def plan_info(plan: str, expiration: int) -> TenantPlanInfo: + return TenantPlanInfo(plan=plan, expiration_date=expiration) + + +def create_cleanup( + monkeypatch: pytest.MonkeyPatch, + repo: FakeRepo, + *, + grace_period_days: int = 0, + **kwargs: Any, +) -> WorkflowRunCleanup: + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_FREE_PLAN_GRACE_PERIOD_DAYS", grace_period_days) return WorkflowRunCleanup(workflow_run_repo=repo, **kwargs) @@ -71,7 +83,7 @@ def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) - monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) - def fail_bulk(_: list[str]) -> dict[str, dict[str, Any]]: + def fail_bulk(_: list[str]) -> dict[str, TenantPlanInfo]: raise RuntimeError("should not call") monkeypatch.setattr(cleanup_module.BillingService, "get_info_bulk", staticmethod(fail_bulk)) @@ -86,12 +98,12 @@ def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) - cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX - cleanup.billing_cache["t_paid"] = cleanup_module.CloudPlan.TEAM + cleanup.billing_cache["t_free"] = plan_info("sandbox", -1) + cleanup.billing_cache["t_paid"] = plan_info("team", -1) monkeypatch.setattr( cleanup_module.BillingService, "get_info_bulk", - staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "sandbox")), + staticmethod(lambda tenant_ids: {tenant_id: plan_info("sandbox", -1) for tenant_id in tenant_ids}), ) free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"}) @@ -99,6 +111,27 @@ def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None assert free == {"t_free", "t_missing"} +def test_filter_free_tenants_respects_grace_period(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, grace_period_days=45) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + now = datetime.datetime.now(datetime.UTC) + within_grace_ts = int((now - datetime.timedelta(days=10)).timestamp()) + outside_grace_ts = int((now - datetime.timedelta(days=90)).timestamp()) + + def fake_bulk(_: list[str]) -> dict[str, TenantPlanInfo]: + return { + "recently_downgraded": plan_info("sandbox", within_grace_ts), + "long_sandbox": plan_info("sandbox", outside_grace_ts), + } + + monkeypatch.setattr(cleanup_module.BillingService, "get_info_bulk", staticmethod(fake_bulk)) + + free = cleanup._filter_free_tenants({"recently_downgraded", "long_sandbox"}) + + assert free == {"long_sandbox"} + + def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None: cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) @@ -127,12 +160,12 @@ def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) - cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX - cleanup.billing_cache["t_paid"] = cleanup_module.CloudPlan.TEAM + cleanup.billing_cache["t_free"] = plan_info("sandbox", -1) + cleanup.billing_cache["t_paid"] = plan_info("team", -1) monkeypatch.setattr( cleanup_module.BillingService, "get_info_bulk", - staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "sandbox")), + staticmethod(lambda tenant_ids: {tenant_id: plan_info("sandbox", -1) for tenant_id in tenant_ids}), ) cleanup.run() @@ -149,7 +182,7 @@ def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None monkeypatch.setattr( cleanup_module.BillingService, "get_info_bulk", - staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "team")), + staticmethod(lambda tenant_ids: {tenant_id: plan_info("team", 1893456000) for tenant_id in tenant_ids}), ) cleanup.run()