mirror of
https://github.com/langgenius/dify.git
synced 2026-04-24 00:59:19 +08:00
add grace_deadline logic
This commit is contained in:
parent
f5952b3884
commit
c42e7c8a97
@ -647,6 +647,13 @@ class BillingConfig(BaseSettings):
|
|||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
BILLING_FREE_PLAN_GRACE_PERIOD_DAYS: NonNegativeInt = Field(
|
||||||
|
description=(
|
||||||
|
"Extra grace period in days applied after a tenant leaves a paid plan before being treated as free."
|
||||||
|
),
|
||||||
|
default=21,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UpdateConfig(BaseSettings):
|
class UpdateConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from collections.abc import Sequence
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
|
||||||
@ -16,6 +17,11 @@ from models import Account, TenantAccountJoin, TenantAccountRole
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TenantPlanInfo(BaseModel):
|
||||||
|
plan: CloudPlan
|
||||||
|
expiration_date: int
|
||||||
|
|
||||||
|
|
||||||
class BillingService:
|
class BillingService:
|
||||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||||
@ -30,32 +36,46 @@ class BillingService:
|
|||||||
return billing_info
|
return billing_info
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, str]:
|
def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, TenantPlanInfo]:
|
||||||
"""
|
"""
|
||||||
Bulk billing info fetch via billing API.
|
Bulk billing info fetch via billing API.
|
||||||
|
|
||||||
Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
|
Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Mapping of tenant_id -> plan
|
Mapping of tenant_id -> TenantPlanInfo(plan + expiration timestamp)
|
||||||
"""
|
"""
|
||||||
results: dict[str, str] = {}
|
results: dict[str, TenantPlanInfo] = {}
|
||||||
|
|
||||||
chunk_size = 200
|
chunk_size = 200
|
||||||
for i in range(0, len(tenant_ids), chunk_size):
|
for i in range(0, len(tenant_ids), chunk_size):
|
||||||
chunk = tenant_ids[i : i + chunk_size]
|
chunk = tenant_ids[i : i + chunk_size]
|
||||||
try:
|
try:
|
||||||
resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
|
resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
|
||||||
data = resp.get("data", {})
|
results.update(cls._parse_bulk_response(chunk, resp))
|
||||||
for tenant_id, plan in data.items():
|
|
||||||
if isinstance(plan, str):
|
|
||||||
results[tenant_id] = plan
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
|
logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
|
||||||
continue
|
raise
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_bulk_response(cls, expected_ids: Sequence[str], response: dict) -> dict[str, TenantPlanInfo]:
|
||||||
|
data = response.get("data")
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError("Billing API response missing 'data' object.")
|
||||||
|
|
||||||
|
parsed: dict[str, TenantPlanInfo] = {}
|
||||||
|
for tenant_id in expected_ids:
|
||||||
|
payload = data.get(tenant_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed[tenant_id] = TenantPlanInfo.model_validate(payload)
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise ValueError(f"Invalid billing info for tenant {tenant_id}") from exc
|
||||||
|
|
||||||
|
return parsed
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
|
def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
|
||||||
params = {"tenant_id": tenant_id}
|
params = {"tenant_id": tenant_id}
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
|||||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||||
)
|
)
|
||||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService, TenantPlanInfo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -40,8 +40,9 @@ class WorkflowRunCleanup:
|
|||||||
raise ValueError("end_before must be greater than start_after.")
|
raise ValueError("end_before must be greater than start_after.")
|
||||||
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.billing_cache: dict[str, CloudPlan | None] = {}
|
self.billing_cache: dict[str, TenantPlanInfo | None] = {}
|
||||||
self.dry_run = dry_run
|
self.dry_run = dry_run
|
||||||
|
self.free_plan_grace_period_days = dify_config.BILLING_FREE_PLAN_GRACE_PERIOD_DAYS
|
||||||
self.workflow_run_repo: APIWorkflowRunRepository
|
self.workflow_run_repo: APIWorkflowRunRepository
|
||||||
if workflow_run_repo:
|
if workflow_run_repo:
|
||||||
self.workflow_run_repo = workflow_run_repo
|
self.workflow_run_repo = workflow_run_repo
|
||||||
@ -157,10 +158,14 @@ class WorkflowRunCleanup:
|
|||||||
click.echo(click.style(summary_message, fg=summary_color))
|
click.echo(click.style(summary_message, fg=summary_color))
|
||||||
|
|
||||||
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
|
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
|
||||||
if not dify_config.BILLING_ENABLED:
|
|
||||||
return set(tenant_ids)
|
|
||||||
|
|
||||||
tenant_id_list = list(tenant_ids)
|
tenant_id_list = list(tenant_ids)
|
||||||
|
|
||||||
|
if not dify_config.BILLING_ENABLED:
|
||||||
|
return set(tenant_id_list)
|
||||||
|
|
||||||
|
if not tenant_id_list:
|
||||||
|
return set()
|
||||||
|
|
||||||
uncached_tenants = [tenant_id for tenant_id in tenant_id_list if tenant_id not in self.billing_cache]
|
uncached_tenants = [tenant_id for tenant_id in tenant_id_list if tenant_id not in self.billing_cache]
|
||||||
|
|
||||||
if uncached_tenants:
|
if uncached_tenants:
|
||||||
@ -171,19 +176,47 @@ class WorkflowRunCleanup:
|
|||||||
logger.exception("Failed to fetch billing plans in bulk for tenants: %s", uncached_tenants)
|
logger.exception("Failed to fetch billing plans in bulk for tenants: %s", uncached_tenants)
|
||||||
|
|
||||||
for tenant_id in uncached_tenants:
|
for tenant_id in uncached_tenants:
|
||||||
plan: CloudPlan | None = None
|
|
||||||
info = bulk_info.get(tenant_id)
|
info = bulk_info.get(tenant_id)
|
||||||
if info:
|
if info is None:
|
||||||
try:
|
|
||||||
plan = CloudPlan(info)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to parse billing plan for tenant %s", tenant_id)
|
|
||||||
else:
|
|
||||||
logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id)
|
logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id)
|
||||||
|
self.billing_cache[tenant_id] = info
|
||||||
|
|
||||||
self.billing_cache[tenant_id] = plan
|
eligible_free_tenants: set[str] = set()
|
||||||
|
for tenant_id in tenant_id_list:
|
||||||
|
info = self.billing_cache.get(tenant_id)
|
||||||
|
if not info:
|
||||||
|
continue
|
||||||
|
|
||||||
return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX}
|
if info.plan != CloudPlan.SANDBOX:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self._is_within_grace_period(tenant_id, info):
|
||||||
|
continue
|
||||||
|
|
||||||
|
eligible_free_tenants.add(tenant_id)
|
||||||
|
|
||||||
|
return eligible_free_tenants
|
||||||
|
|
||||||
|
def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None:
|
||||||
|
if expiration_value < 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC)
|
||||||
|
except (OverflowError, OSError, ValueError):
|
||||||
|
logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _is_within_grace_period(self, tenant_id: str, info: TenantPlanInfo) -> bool:
|
||||||
|
if self.free_plan_grace_period_days <= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
expiration_at = self._expiration_datetime(tenant_id, info.expiration_date)
|
||||||
|
if expiration_at is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days)
|
||||||
|
return datetime.datetime.now(datetime.UTC) < grace_deadline
|
||||||
|
|
||||||
def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
|
def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
|
||||||
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Any
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from services import clear_free_plan_expired_workflow_run_logs as cleanup_module
|
from services import clear_free_plan_expired_workflow_run_logs as cleanup_module
|
||||||
|
from services.billing_service import TenantPlanInfo
|
||||||
from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||||
|
|
||||||
|
|
||||||
@ -62,7 +63,18 @@ class FakeRepo:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def create_cleanup(monkeypatch: pytest.MonkeyPatch, repo: FakeRepo, **kwargs: Any) -> WorkflowRunCleanup:
|
def plan_info(plan: str, expiration: int) -> TenantPlanInfo:
|
||||||
|
return TenantPlanInfo(plan=plan, expiration_date=expiration)
|
||||||
|
|
||||||
|
|
||||||
|
def create_cleanup(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
repo: FakeRepo,
|
||||||
|
*,
|
||||||
|
grace_period_days: int = 0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> WorkflowRunCleanup:
|
||||||
|
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_FREE_PLAN_GRACE_PERIOD_DAYS", grace_period_days)
|
||||||
return WorkflowRunCleanup(workflow_run_repo=repo, **kwargs)
|
return WorkflowRunCleanup(workflow_run_repo=repo, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -71,7 +83,7 @@ def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -
|
|||||||
|
|
||||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
||||||
|
|
||||||
def fail_bulk(_: list[str]) -> dict[str, dict[str, Any]]:
|
def fail_bulk(_: list[str]) -> dict[str, TenantPlanInfo]:
|
||||||
raise RuntimeError("should not call")
|
raise RuntimeError("should not call")
|
||||||
|
|
||||||
monkeypatch.setattr(cleanup_module.BillingService, "get_info_bulk", staticmethod(fail_bulk))
|
monkeypatch.setattr(cleanup_module.BillingService, "get_info_bulk", staticmethod(fail_bulk))
|
||||||
@ -86,12 +98,12 @@ def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None
|
|||||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||||
|
|
||||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||||
cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX
|
cleanup.billing_cache["t_free"] = plan_info("sandbox", -1)
|
||||||
cleanup.billing_cache["t_paid"] = cleanup_module.CloudPlan.TEAM
|
cleanup.billing_cache["t_paid"] = plan_info("team", -1)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
cleanup_module.BillingService,
|
cleanup_module.BillingService,
|
||||||
"get_info_bulk",
|
"get_info_bulk",
|
||||||
staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "sandbox")),
|
staticmethod(lambda tenant_ids: {tenant_id: plan_info("sandbox", -1) for tenant_id in tenant_ids}),
|
||||||
)
|
)
|
||||||
|
|
||||||
free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"})
|
free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"})
|
||||||
@ -99,6 +111,27 @@ def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None
|
|||||||
assert free == {"t_free", "t_missing"}
|
assert free == {"t_free", "t_missing"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_free_tenants_respects_grace_period(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, grace_period_days=45)
|
||||||
|
|
||||||
|
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||||
|
now = datetime.datetime.now(datetime.UTC)
|
||||||
|
within_grace_ts = int((now - datetime.timedelta(days=10)).timestamp())
|
||||||
|
outside_grace_ts = int((now - datetime.timedelta(days=90)).timestamp())
|
||||||
|
|
||||||
|
def fake_bulk(_: list[str]) -> dict[str, TenantPlanInfo]:
|
||||||
|
return {
|
||||||
|
"recently_downgraded": plan_info("sandbox", within_grace_ts),
|
||||||
|
"long_sandbox": plan_info("sandbox", outside_grace_ts),
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr(cleanup_module.BillingService, "get_info_bulk", staticmethod(fake_bulk))
|
||||||
|
|
||||||
|
free = cleanup._filter_free_tenants({"recently_downgraded", "long_sandbox"})
|
||||||
|
|
||||||
|
assert free == {"long_sandbox"}
|
||||||
|
|
||||||
|
|
||||||
def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||||
|
|
||||||
@ -127,12 +160,12 @@ def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||||
|
|
||||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||||
cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX
|
cleanup.billing_cache["t_free"] = plan_info("sandbox", -1)
|
||||||
cleanup.billing_cache["t_paid"] = cleanup_module.CloudPlan.TEAM
|
cleanup.billing_cache["t_paid"] = plan_info("team", -1)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
cleanup_module.BillingService,
|
cleanup_module.BillingService,
|
||||||
"get_info_bulk",
|
"get_info_bulk",
|
||||||
staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "sandbox")),
|
staticmethod(lambda tenant_ids: {tenant_id: plan_info("sandbox", -1) for tenant_id in tenant_ids}),
|
||||||
)
|
)
|
||||||
|
|
||||||
cleanup.run()
|
cleanup.run()
|
||||||
@ -149,7 +182,7 @@ def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
cleanup_module.BillingService,
|
cleanup_module.BillingService,
|
||||||
"get_info_bulk",
|
"get_info_bulk",
|
||||||
staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "team")),
|
staticmethod(lambda tenant_ids: {tenant_id: plan_info("team", 1893456000) for tenant_id in tenant_ids}),
|
||||||
)
|
)
|
||||||
|
|
||||||
cleanup.run()
|
cleanup.run()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user