mirror of https://github.com/langgenius/dify.git
217 lines
6.9 KiB
Python
217 lines
6.9 KiB
Python
import datetime
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Callable, Sequence
|
|
from dataclasses import dataclass
|
|
|
|
from configs import dify_config
|
|
from enums.cloud_plan import CloudPlan
|
|
from services.billing_service import BillingService, SubscriptionPlan
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class SimpleMessage:
    """Minimal message record handed to the cleanup policies."""

    # Message identifier; this is the value policies return for deletion.
    id: str
    # Identifier of the owning app; policies use it as the key into app_to_tenant.
    app_id: str
    # Creation time of the message (not consulted by the policies visible here).
    created_at: datetime.datetime
|
|
|
|
|
|
class MessagesCleanPolicy(ABC):
    """
    Contract shared by every message cleanup policy.

    Given one batch of messages, an implementation decides which of them
    should be deleted.
    """

    @abstractmethod
    def filter_message_ids(
        self,
        messages: Sequence[SimpleMessage],
        app_to_tenant: dict[str, str],
    ) -> Sequence[str]:
        """
        Select the messages of the batch that should be deleted.

        Args:
            messages: Batch of messages to evaluate
            app_to_tenant: Mapping from app_id to tenant_id

        Returns:
            IDs of the messages that should be deleted
        """
|
|
|
|
|
|
class BillingDisabledPolicy(MessagesCleanPolicy):
    """
    Policy for the community or enterprise edition (billing disabled).

    Applies no filtering: every message of the batch is selected.
    """

    def filter_message_ids(
        self,
        messages: Sequence[SimpleMessage],
        app_to_tenant: dict[str, str],
    ) -> Sequence[str]:
        """
        Return the ID of every message in the batch.

        Args:
            messages: Batch of messages to evaluate
            app_to_tenant: Mapping from app_id to tenant_id (ignored by this policy)

        Returns:
            List with the ID of each message, in batch order
        """
        selected_ids = []
        for message in messages:
            selected_ids.append(message.id)
        return selected_ids
|
|
|
|
|
|
class BillingSandboxPolicy(MessagesCleanPolicy):
    """
    Policy for sandbox plan tenants in cloud edition (billing enabled).

    Filters messages based on sandbox plan expiration rules:
    - Skip tenants in the whitelist
    - Only delete messages from sandbox plan tenants
    - Respect grace period after subscription expiration
    - Safe default: if tenant mapping or plan is missing, do NOT delete
    """

    def __init__(
        self,
        plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
        graceful_period_days: int = 21,
        tenant_whitelist: Sequence[str] | None = None,
        current_timestamp: int | None = None,
    ) -> None:
        """
        Args:
            plan_provider: Callable that receives tenant IDs and returns a mapping
                from tenant_id to that tenant's subscription plan
            graceful_period_days: Days after subscription expiration during which
                messages are still kept
            tenant_whitelist: Tenant IDs whose messages are never selected
                (None is treated as an empty whitelist)
            current_timestamp: Unix timestamp to use as "now"; when None the
                current UTC time is read each time a batch is filtered
        """
        self._graceful_period_days = graceful_period_days
        self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
        self._plan_provider = plan_provider
        self._current_timestamp = current_timestamp

    def filter_message_ids(
        self,
        messages: Sequence[SimpleMessage],
        app_to_tenant: dict[str, str],
    ) -> Sequence[str]:
        """
        Filter messages based on sandbox plan expiration rules.

        Args:
            messages: Batch of messages to evaluate
            app_to_tenant: Mapping from app_id to tenant_id

        Returns:
            List of message IDs that should be deleted
        """
        if not messages or not app_to_tenant:
            return []

        # Get unique tenant_ids and fetch subscription plans
        tenant_ids = list(set(app_to_tenant.values()))
        tenant_plans = self._plan_provider(tenant_ids)

        # Safe default: without any plan information nothing is deleted.
        if not tenant_plans:
            return []

        # Apply sandbox deletion rules
        return self._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
        )

    def _filter_expired_sandbox_messages(
        self,
        messages: Sequence[SimpleMessage],
        app_to_tenant: dict[str, str],
        tenant_plans: dict[str, SubscriptionPlan],
    ) -> list[str]:
        """
        Filter messages that should be deleted based on sandbox plan expiration.

        A message should be deleted if:
        1. It belongs to a sandbox tenant that is not whitelisted AND
        2. Either:
           a) The tenant has no previous subscription (expiration_date == -1), OR
           b) The subscription expired more than graceful_period_days ago

        Args:
            messages: List of message objects with id and app_id attributes
            app_to_tenant: Mapping from app_id to tenant_id
            tenant_plans: Mapping from tenant_id to subscription plan info

        Returns:
            List of message IDs that should be deleted, in batch order
        """
        current_timestamp = self._current_timestamp
        if current_timestamp is None:
            current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())

        graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60

        # Build the whitelist lookup once per batch so that each membership test
        # is a hash lookup instead of a linear scan repeated for every message.
        whitelist = frozenset(self._tenant_whitelist)

        # The verdict only depends on the tenant, so evaluate it once per tenant
        # and reuse it for every further message of that tenant.
        deletable_by_tenant: dict[str, bool] = {}

        sandbox_message_ids: list[str] = []
        for msg in messages:
            # Get tenant_id for this message's app; unknown apps are never deleted
            tenant_id = app_to_tenant.get(msg.app_id)
            if not tenant_id:
                continue

            deletable = deletable_by_tenant.get(tenant_id)
            if deletable is None:
                tenant_plan = tenant_plans.get(tenant_id)
                if tenant_id in whitelist or not tenant_plan:
                    # Whitelisted tenant or missing plan: keep the messages
                    deletable = False
                else:
                    deletable = self._is_expired_sandbox_plan(
                        tenant_plan,
                        current_timestamp,
                        graceful_period_seconds,
                    )
                deletable_by_tenant[tenant_id] = deletable

            if deletable:
                sandbox_message_ids.append(msg.id)

        return sandbox_message_ids

    @staticmethod
    def _is_expired_sandbox_plan(
        tenant_plan: SubscriptionPlan,
        current_timestamp: int,
        graceful_period_seconds: int,
    ) -> bool:
        """Return True when the plan is a sandbox plan whose messages may be deleted."""
        plan = str(tenant_plan["plan"])
        expiration_date = int(tenant_plan["expiration_date"])

        # Only process sandbox plans
        if plan != CloudPlan.SANDBOX:
            return False

        # Case 1: No previous subscription (-1 means never had a paid subscription)
        if expiration_date == -1:
            return True

        # Case 2: Subscription expired beyond grace period
        return current_timestamp - expiration_date > graceful_period_seconds
|
|
|
|
|
|
def create_message_clean_policy(
    graceful_period_days: int = 21,
    current_timestamp: int | None = None,
) -> MessagesCleanPolicy:
    """
    Build the message clean policy that matches the billing configuration.

    The choice is driven by dify_config.BILLING_ENABLED:
    - truthy: a BillingSandboxPolicy, wired to the BillingService bulk plan
      lookup and the whitelist fetched from BillingService
    - falsy: a BillingDisabledPolicy

    Args:
        graceful_period_days: Grace period in days after subscription expiration (default: 21)
        current_timestamp: Current Unix timestamp for testing (default: None, uses current time)

    Returns:
        The policy instance selected for the current configuration
    """
    if dify_config.BILLING_ENABLED:
        # Billing enabled - the whitelist comes from BillingService
        whitelisted_tenants = BillingService.get_expired_subscription_cleanup_whitelist()
        bulk_plan_lookup = BillingService.get_plan_bulk_with_cache

        logger.info(
            "create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
            "(graceful_period_days=%s, whitelist=%s)",
            graceful_period_days,
            whitelisted_tenants,
        )

        return BillingSandboxPolicy(
            plan_provider=bulk_plan_lookup,
            graceful_period_days=graceful_period_days,
            tenant_whitelist=whitelisted_tenants,
            current_timestamp=current_timestamp,
        )

    logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
    return BillingDisabledPolicy()
|