dify/api/services/retention/conversation/messages_clean_policy.py

217 lines
6.9 KiB
Python

import datetime
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from configs import dify_config
from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
@dataclass
class SimpleMessage:
id: str
app_id: str
created_at: datetime.datetime
class MessagesCleanPolicy(ABC):
"""
Abstract base class for message cleanup policies.
A policy determines which messages from a batch should be deleted.
"""
@abstractmethod
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages and return IDs of messages that should be deleted.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
...
class BillingDisabledPolicy(MessagesCleanPolicy):
"""
Policy for community or enterpriseedition (billing disabled).
No special filter logic, just return all message ids.
"""
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
return [msg.id for msg in messages]
class BillingSandboxPolicy(MessagesCleanPolicy):
"""
Policy for sandbox plan tenants in cloud edition (billing enabled).
Filters messages based on sandbox plan expiration rules:
- Skip tenants in the whitelist
- Only delete messages from sandbox plan tenants
- Respect grace period after subscription expiration
- Safe default: if tenant mapping or plan is missing, do NOT delete
"""
def __init__(
self,
plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]],
graceful_period_days: int = 21,
tenant_whitelist: Sequence[str] | None = None,
current_timestamp: int | None = None,
) -> None:
self._graceful_period_days = graceful_period_days
self._tenant_whitelist: Sequence[str] = tenant_whitelist or []
self._plan_provider = plan_provider
self._current_timestamp = current_timestamp
def filter_message_ids(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
) -> Sequence[str]:
"""
Filter messages based on sandbox plan expiration rules.
Args:
messages: Batch of messages to evaluate
app_to_tenant: Mapping from app_id to tenant_id
Returns:
List of message IDs that should be deleted
"""
if not messages or not app_to_tenant:
return []
# Get unique tenant_ids and fetch subscription plans
tenant_ids = list(set(app_to_tenant.values()))
tenant_plans = self._plan_provider(tenant_ids)
if not tenant_plans:
return []
# Apply sandbox deletion rules
return self._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
)
def _filter_expired_sandbox_messages(
self,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
tenant_plans: dict[str, SubscriptionPlan],
) -> list[str]:
"""
Filter messages that should be deleted based on sandbox plan expiration.
A message should be deleted if:
1. It belongs to a sandbox tenant AND
2. Either:
a) The tenant has no previous subscription (expiration_date == -1), OR
b) The subscription expired more than graceful_period_days ago
Args:
messages: List of message objects with id and app_id attributes
app_to_tenant: Mapping from app_id to tenant_id
tenant_plans: Mapping from tenant_id to subscription plan info
Returns:
List of message IDs that should be deleted
"""
current_timestamp = self._current_timestamp
if current_timestamp is None:
current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
sandbox_message_ids: list[str] = []
graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60
for msg in messages:
# Get tenant_id for this message's app
tenant_id = app_to_tenant.get(msg.app_id)
if not tenant_id:
continue
# Skip tenant messages in whitelist
if tenant_id in self._tenant_whitelist:
continue
# Get subscription plan for this tenant
tenant_plan = tenant_plans.get(tenant_id)
if not tenant_plan:
continue
plan = str(tenant_plan["plan"])
expiration_date = int(tenant_plan["expiration_date"])
# Only process sandbox plans
if plan != CloudPlan.SANDBOX:
continue
# Case 1: No previous subscription (-1 means never had a paid subscription)
if expiration_date == -1:
sandbox_message_ids.append(msg.id)
continue
# Case 2: Subscription expired beyond grace period
if current_timestamp - expiration_date > graceful_period_seconds:
sandbox_message_ids.append(msg.id)
return sandbox_message_ids
def create_message_clean_policy(
graceful_period_days: int = 21,
current_timestamp: int | None = None,
) -> MessagesCleanPolicy:
"""
Factory function to create the appropriate message clean policy.
Determines which policy to use based on BILLING_ENABLED configuration:
- If BILLING_ENABLED is True: returns BillingSandboxPolicy
- If BILLING_ENABLED is False: returns BillingDisabledPolicy
Args:
graceful_period_days: Grace period in days after subscription expiration (default: 21)
current_timestamp: Current Unix timestamp for testing (default: None, uses current time)
"""
if not dify_config.BILLING_ENABLED:
logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy")
return BillingDisabledPolicy()
# Billing enabled - fetch whitelist from BillingService
tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist()
plan_provider = BillingService.get_plan_bulk_with_cache
logger.info(
"create_message_clean_policy: billing enabled, using BillingSandboxPolicy "
"(graceful_period_days=%s, whitelist=%s)",
graceful_period_days,
tenant_whitelist,
)
return BillingSandboxPolicy(
plan_provider=plan_provider,
graceful_period_days=graceful_period_days,
tenant_whitelist=tenant_whitelist,
current_timestamp=current_timestamp,
)