From c3852833562212c074624a8a2ca97dd85cff00b1 Mon Sep 17 00:00:00 2001 From: hj24 Date: Tue, 13 Jan 2026 23:14:31 +0800 Subject: [PATCH] refactor: enhance clean message task --- api/.env.example | 1 + api/commands.py | 79 ++ api/extensions/ext_commands.py | 2 + ...eat_add_created_at_id_index_to_messages.py | 33 + api/models/model.py | 1 + api/schedule/clean_messages.py | 126 +- .../conversation/messages_clean_policy.py | 214 ++++ .../conversation/messages_clean_service.py | 325 +++++ .../services/test_messages_clean_service.py | 1090 +++++++++++++++++ .../services/test_messages_clean_service.py | 627 ++++++++++ 10 files changed, 2421 insertions(+), 77 deletions(-) create mode 100644 api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py create mode 100644 api/services/retention/conversation/messages_clean_policy.py create mode 100644 api/services/retention/conversation/messages_clean_service.py create mode 100644 api/tests/test_containers_integration_tests/services/test_messages_clean_service.py create mode 100644 api/tests/unit_tests/services/test_messages_clean_service.py diff --git a/api/.env.example b/api/.env.example index 8099c4a42a..9da1900dc7 100644 --- a/api/.env.example +++ b/api/.env.example @@ -709,6 +709,7 @@ ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5 ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20 # Maximum number of concurrent annotation import tasks per tenant ANNOTATION_IMPORT_MAX_CONCURRENT=5 + # Sandbox expired records clean configuration SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21 SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000 diff --git a/api/commands.py b/api/commands.py index e24b1826ee..d0949fe0d6 100644 --- a/api/commands.py +++ b/api/commands.py @@ -3,6 +3,7 @@ import datetime import json import logging import secrets +import time from typing import Any import click @@ -46,6 +47,8 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration from services.plugin.plugin_service import PluginService +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup from tasks.remove_app_and_related_data_task import delete_draft_variables_batch @@ -2168,3 +2171,79 @@ def migrate_oss( except Exception as e: db.session.rollback() click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red")) + + +@click.command("clean-expired-messages", help="Clean expired messages.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Lower bound (inclusive) for created_at.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Upper bound (exclusive) for created_at.", +) +@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.") +@click.option( + "--graceful-period", + default=21, + show_default=True, + help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.", +) +@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleteing") +def clean_expired_messages( + batch_size: int, + graceful_period: int, + start_from: datetime.datetime, + end_before: datetime.datetime, + dry_run: bool, +): + """ + Clean expired messages and related data for tenants based on clean policy. + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + + start_at = time.perf_counter() + + try: + # Create policy based on billing configuration + # NOTE: graceful_period will be ignored when billing is disabled. + policy = create_message_clean_policy(graceful_period_days=graceful_period) + + # Create and run the cleanup service + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise + + click.echo(click.style("messages cleanup completed.", fg="green")) diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index c32130d377..51e2c6cdd5 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + clean_expired_messages, clean_workflow_runs, cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, @@ -58,6 +59,7 @@ def init_app(app: DifyApp): transform_datasource_credentials, install_rag_pipeline_plugins, clean_workflow_runs, + clean_expired_messages, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py new file mode 100644 index 0000000000..758369ba99 --- /dev/null +++ b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py @@ -0,0 +1,33 @@ +"""feat: add created_at id index to messages + +Revision ID: 3334862ee907 +Revises: 905527cc8fd3 +Create Date: 2026-01-12 17:29:44.846544 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3334862ee907' +down_revision = '905527cc8fd3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_created_at_id_idx') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index a48f4d34d4..603b16b249 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -968,6 +968,7 @@ class Message(Base): Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), Index("message_created_at_idx", "created_at"), Index("message_app_mode_idx", "app_mode"), + Index("message_created_at_id_idx", "created_at", "id"), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 352a84b592..e85bba8823 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -1,90 +1,62 @@ -import datetime import logging import time import click -from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config -from enums.cloud_plan import CloudPlan -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from models.model import ( - App, - Message, - MessageAgentThought, - MessageAnnotation, - MessageChain, - MessageFeedback, - MessageFile, -) -from models.web import SavedMessage -from services.feature_service import FeatureService +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService logger = logging.getLogger(__name__) -@app.celery.task(queue="dataset") +@app.celery.task(queue="retention") def clean_messages(): - click.echo(click.style("Start clean messages.", fg="green")) - start_at = time.perf_counter() - plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( - days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING - ) - while True: - try: - # Main query with join and filter - messages = ( - db.session.query(Message) - .where(Message.created_at < plan_sandbox_clean_message_day) - .order_by(Message.created_at.desc()) - .limit(100) - .all() - ) + """ + Clean expired messages based on clean policy. - except SQLAlchemyError: - raise - if not messages: - break - for message in messages: - app = db.session.query(App).filter_by(id=message.app_id).first() - if not app: - logger.warning( - "Expected App record to exist, but none was found, app_id=%s, message_id=%s", - message.app_id, - message.id, - ) - continue - features_cache_key = f"features:{app.tenant_id}" - plan_cache = redis_client.get(features_cache_key) - if plan_cache is None: - features = FeatureService.get_features(app.tenant_id) - redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) - plan = features.billing.subscription.plan - else: - plan = plan_cache.decode() - if plan == CloudPlan.SANDBOX: - # clean related message - db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(Message).where(Message.id == message.id).delete() - db.session.commit() - end_at = time.perf_counter() - click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green")) + This task uses MessagesCleanService to efficiently clean messages in batches. + The behavior depends on BILLING_ENABLED configuration: + - BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period) + - BILLING_ENABLED=False: delete all messages within the time range + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + start_at = time.perf_counter() + + try: + # Create policy based on billing configuration + policy = create_message_clean_policy( + graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD, + ) + + # Create and run the cleanup service + service = MessagesCleanService.from_days( + policy=policy, + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py new file mode 100644 index 0000000000..d36d1af00b --- /dev/null +++ b/api/services/retention/conversation/messages_clean_policy.py @@ -0,0 +1,214 @@ +import datetime +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from services.billing_service import BillingService, SubscriptionPlan + +logger = logging.getLogger(__name__) + + +@dataclass +class SimpleMessage: + id: str + app_id: str + created_at: datetime.datetime + + +class MessagesCleanPolicy(ABC): + """ + Abstract base class for message cleanup policies. + + A policy determines which messages from a batch should be deleted. + """ + + @abstractmethod + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + """ + Filter messages and return IDs of messages that should be deleted. + + Args: + messages: Batch of messages to evaluate + app_to_tenant: Mapping from app_id to tenant_id + + Returns: + List of message IDs that should be deleted + """ + ... + + +class BillingDisabledPolicy(MessagesCleanPolicy): + """ + Policy for community or enterpriseedition (billing disabled). + + No special filter logic, just return all message ids. + """ + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + return [msg.id for msg in messages] + + +class BillingSandboxPolicy(MessagesCleanPolicy): + """ + Policy for sandbox plan tenants in cloud edition (billing enabled). + + Filters messages based on sandbox plan expiration rules: + - Skip tenants in the whitelist + - Only delete messages from sandbox plan tenants + - Respect grace period after subscription expiration + - Safe default: if tenant mapping or plan is missing, do NOT delete + """ + def __init__( + self, + plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]], + graceful_period_days: int = 21, + tenant_whitelist: Sequence[str] | None = None, + current_timestamp: int | None = None, + ) -> None: + self._graceful_period_days = graceful_period_days + self._tenant_whitelist: Sequence[str] = tenant_whitelist or [] + self._plan_provider = plan_provider + self._current_timestamp = current_timestamp + + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + """ + Filter messages based on sandbox plan expiration rules. + + Args: + messages: Batch of messages to evaluate + app_to_tenant: Mapping from app_id to tenant_id + + Returns: + List of message IDs that should be deleted + """ + if not messages or not app_to_tenant: + return [] + + # Get unique tenant_ids and fetch subscription plans + tenant_ids = list(set(app_to_tenant.values())) + tenant_plans = self._plan_provider(tenant_ids) + + if not tenant_plans: + return [] + + # Apply sandbox deletion rules + return self._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + ) + + def _filter_expired_sandbox_messages( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + tenant_plans: dict[str, SubscriptionPlan], + ) -> list[str]: + """ + Filter messages that should be deleted based on sandbox plan expiration. + + A message should be deleted if: + 1. It belongs to a sandbox tenant AND + 2. Either: + a) The tenant has no previous subscription (expiration_date == -1), OR + b) The subscription expired more than graceful_period_days ago + + Args: + messages: List of message objects with id and app_id attributes + app_to_tenant: Mapping from app_id to tenant_id + tenant_plans: Mapping from tenant_id to subscription plan info + + Returns: + List of message IDs that should be deleted + """ + current_timestamp = self._current_timestamp + if current_timestamp is None: + current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + sandbox_message_ids: list[str] = [] + graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60 + + for msg in messages: + # Get tenant_id for this message's app + tenant_id = app_to_tenant.get(msg.app_id) + if not tenant_id: + continue + + # Skip tenant messages in whitelist + if tenant_id in self._tenant_whitelist: + continue + + # Get subscription plan for this tenant + tenant_plan = tenant_plans.get(tenant_id) + if not tenant_plan: + continue + + plan = str(tenant_plan["plan"]) + expiration_date = int(tenant_plan["expiration_date"]) + + # Only process sandbox plans + if plan != CloudPlan.SANDBOX: + continue + + # Case 1: No previous subscription (-1 means never had a paid subscription) + if expiration_date == -1: + sandbox_message_ids.append(msg.id) + continue + + # Case 2: Subscription expired beyond grace period + if current_timestamp - expiration_date > graceful_period_seconds: + sandbox_message_ids.append(msg.id) + + return sandbox_message_ids + + +def create_message_clean_policy( + graceful_period_days: int = 21, + current_timestamp: int | None = None, +) -> MessagesCleanPolicy: + """ + Factory function to create the appropriate message clean policy. + + Determines which policy to use based on BILLING_ENABLED configuration: + - If BILLING_ENABLED is True: returns BillingSandboxPolicy + - If BILLING_ENABLED is False: returns BillingDisabledPolicy + + Args: + graceful_period_days: Grace period in days after subscription expiration (default: 21) + current_timestamp: Current Unix timestamp for testing (default: None, uses current time) + """ + if not dify_config.BILLING_ENABLED: + logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy") + return BillingDisabledPolicy() + + # Billing enabled - fetch whitelist from BillingService + tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist() + plan_provider = BillingService.get_plan_bulk_with_cache + + logger.info( + "create_message_clean_policy: billing enabled, using BillingSandboxPolicy " + "(graceful_period_days=%s, whitelist=%s)", + graceful_period_days, + tenant_whitelist, + ) + + return BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=graceful_period_days, + tenant_whitelist=tenant_whitelist, + current_timestamp=current_timestamp, + ) diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py new file mode 100644 index 0000000000..a28bc1ba67 --- /dev/null +++ b/api/services/retention/conversation/messages_clean_service.py @@ -0,0 +1,325 @@ +import datetime +import logging +import random +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.model import ( + App, + AppAnnotationHitHistory, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.retention.conversation.messages_clean_policy import ( + MessagesCleanPolicy, + SimpleMessage, +) + +logger = logging.getLogger(__name__) + + +class MessagesCleanService: + """ + Service for cleaning expired messages based on retention policies. + + Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted. + If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support). + """ + + def __init__( + self, + policy: MessagesCleanPolicy, + end_before: datetime.datetime, + start_from: datetime.datetime | None = None, + batch_size: int = 1000, + dry_run: bool = False, + ) -> None: + """ + Initialize the service with cleanup parameters. + + Args: + policy: The policy that determines which messages to delete + end_before: End time (exclusive) of the range + start_from: Optional start time (inclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + """ + self._policy = policy + self._end_before = end_before + self._start_from = start_from + self._batch_size = batch_size + self._dry_run = dry_run + + @classmethod + def from_time_range( + cls, + policy: MessagesCleanPolicy, + start_from: datetime.datetime, + end_before: datetime.datetime, + batch_size: int = 1000, + dry_run: bool = False, + ) -> "MessagesCleanService": + """ + Create a service instance for cleaning messages within a specific time range. + + Time range is [start_from, end_before). + + Args: + policy: The policy that determines which messages to delete + start_from: Start time (inclusive) of the range + end_before: End time (exclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + MessagesCleanService instance + + Raises: + ValueError: If start_from >= end_before or invalid parameters + """ + if start_from >= end_before: + raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + logger.info( + "clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s", + start_from, end_before, batch_size, policy, + ) + + return cls( + policy=policy, + end_before=end_before, + start_from=start_from, + batch_size=batch_size, + dry_run=dry_run, + ) + + @classmethod + def from_days( + cls, + policy: MessagesCleanPolicy, + days: int = 30, + batch_size: int = 1000, + dry_run: bool = False, + ) -> "MessagesCleanService": + """ + Create a service instance for cleaning messages older than specified days. + + Args: + policy: The policy that determines which messages to delete + days: Number of days to look back from now + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + MessagesCleanService instance + + Raises: + ValueError: If invalid parameters + """ + if days < 0: + raise ValueError(f"days ({days}) must be greater than or equal to 0") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + end_before = datetime.datetime.now() - datetime.timedelta(days=days) + + logger.info( + "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s", + days, end_before, batch_size, policy, + ) + + return cls( + policy=policy, + end_before=end_before, + start_from=None, + batch_size=batch_size, + dry_run=dry_run + ) + + def run(self) -> dict[str, int]: + """ + Execute the message cleanup operation. + + Returns: + Dict with statistics: batches, filtered_messages, total_deleted + """ + return self._clean_messages_by_time_range() + + def _clean_messages_by_time_range(self) -> dict[str, int]: + """ + Clean messages within a time range using cursor-based pagination. + + Time range is [start_from, end_before) + + Steps: + 1. Iterate messages using cursor pagination (by created_at, id) + 2. Query app_id -> tenant_id mapping + 3. Delegate to policy to determine which messages to delete + 4. Batch delete messages and their relations + + Returns: + Dict with statistics: batches, filtered_messages, total_deleted + """ + stats = { + "batches": 0, + "total_messages": 0, + "filtered_messages": 0, + "total_deleted": 0, + } + + # Cursor-based pagination using (created_at, id) to avoid infinite loops + # and ensure proper ordering with time-based filtering + _cursor: tuple[datetime.datetime, str] | None = None + + logger.info( + "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s", + self._dry_run, self._start_from, self._end_before, + ) + + while True: + stats["batches"] += 1 + + # Step 1: Fetch a batch of messages using cursor + with Session(db.engine, expire_on_commit=False) as session: + msg_stmt = ( + select(Message.id, Message.app_id, Message.created_at) + .where(Message.created_at < self._end_before) + .order_by(Message.created_at, Message.id) + .limit(self._batch_size) + ) + + if self._start_from: + msg_stmt = msg_stmt.where(Message.created_at >= self._start_from) + + # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id) + # This translates to: + # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id) + if _cursor: + # Continuing from previous batch + msg_stmt = msg_stmt.where( + (Message.created_at > _cursor[0]) + | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1])) + ) + + raw_messages = list(session.execute(msg_stmt).all()) + messages = [ + SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at) + for msg_id, app_id, msg_created_at in raw_messages + ] + + # Track total messages fetched across all batches + stats["total_messages"] += len(messages) + + if not messages: + logger.info("clean_messages (batch %s): no more messages to process", stats["batches"]) + break + + # Update cursor to the last message's (created_at, id) + _cursor = (messages[-1].created_at, messages[-1].id) + + # Step 2: Extract app_ids and query tenant_ids + app_ids = list({msg.app_id for msg in messages}) + + if not app_ids: + logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"]) + continue + + app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids)) + apps = list(session.execute(app_stmt).all()) + + if not apps: + logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"]) + continue + + # Build app_id -> tenant_id mapping + app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps} + + # Step 3: Delegate to policy to determine which messages to delete + message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant) + + if not message_ids_to_delete: + logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"]) + continue + + stats["filtered_messages"] += len(message_ids_to_delete) + + # Step 4: Batch delete messages and their relations + if not self._dry_run: + with Session(db.engine, expire_on_commit=False) as session: + # Delete related records first + self._batch_delete_message_relations(session, message_ids_to_delete) + + # Delete messages + delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete)) + delete_result = cast(CursorResult, session.execute(delete_stmt)) + messages_deleted = delete_result.rowcount + session.commit() + + stats["total_deleted"] += messages_deleted + + logger.info( + "clean_messages (batch %s): processed %s messages, deleted %s messages", + stats["batches"], len(messages), messages_deleted, + ) + else: + # Log random sample of message IDs that would be deleted (up to 10) + sample_size = min(10, len(message_ids_to_delete)) + sampled_ids = random.sample(list(message_ids_to_delete), sample_size) + + logger.info( + "clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:", + stats["batches"], len(message_ids_to_delete), sample_size, + ) + for msg_id in sampled_ids: + logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id) + + logger.info( + "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s", + stats["batches"], stats["total_messages"], stats["filtered_messages"], stats["total_deleted"], + ) + + return stats + + @staticmethod + def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None: + """ + Batch delete all related records for given message IDs. + + Args: + session: Database session + message_ids: List of message IDs to delete relations for + """ + if not message_ids: + return + + # Delete all related records in batch + session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids))) + + session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids))) + + session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids))) + + session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids))) + + session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids))) + + session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids))) + + session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids))) + + session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids))) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py new file mode 100644 index 0000000000..3832da0688 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -0,0 +1,1090 @@ +import datetime +import json +import uuid +from decimal import Decimal +from unittest.mock import patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.billing_service import BillingService +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, + BillingSandboxPolicy, + create_message_clean_policy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +class TestMessagesCleanServiceIntegration: + """Integration tests for MessagesCleanService.run() and _clean_messages_by_time_range().""" + + # Redis cache key prefix from BillingService + PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database before and after each test to ensure isolation.""" + yield + # Clear all test data in correct order (respecting foreign key constraints) + db.session.query(DatasetRetrieverResource).delete() + db.session.query(AppAnnotationHitHistory).delete() + db.session.query(SavedMessage).delete() + db.session.query(MessageFile).delete() + db.session.query(MessageAgentThought).delete() + db.session.query(MessageChain).delete() + db.session.query(MessageAnnotation).delete() + db.session.query(MessageFeedback).delete() + db.session.query(Message).delete() + db.session.query(Conversation).delete() + db.session.query(App).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + @pytest.fixture(autouse=True) + def cleanup_redis(self): + """Clean up Redis cache before each test.""" + # Clear tenant plan cache using BillingService key prefix + try: + keys = redis_client.keys(f"{self.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass # Redis might not be available in some test environments + yield + # Clean up after test + try: + keys = redis_client.keys(f"{self.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass + + @pytest.fixture + def mock_whitelist(self): + """Mock whitelist to return empty list by default.""" + with patch( + "services.retention.conversation.messages_clean_policy.BillingService.get_expired_subscription_cleanup_whitelist" + ) as mock: + mock.return_value = [] + yield mock + + @pytest.fixture + def mock_billing_enabled(self): + """Mock BILLING_ENABLED to be True.""" + with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", True): + yield + + @pytest.fixture + def mock_billing_disabled(self): + """Mock BILLING_ENABLED to be False.""" + with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False): + yield + + def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX): + """Helper to create account and tenant.""" + fake = Faker() + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.flush() + + tenant = Tenant( + name=fake.company(), + plan=str(plan), + status="normal", + ) + db.session.add(tenant) + db.session.flush() + + tenant_account_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + ) + db.session.add(tenant_account_join) + db.session.commit() + + return account, tenant + + def _create_app(self, tenant, account): + """Helper to create an app.""" + fake = Faker() + + app = App( + tenant_id=tenant.id, + name=fake.company(), + description="Test app", + mode="chat", + enable_site=True, + enable_api=True, + api_rpm=60, + api_rph=3600, + is_demo=False, + is_public=False, + created_by=account.id, + updated_by=account.id, + ) + db.session.add(app) + db.session.commit() + + return app + + def _create_conversation(self, app): + """Helper to create a conversation.""" + conversation = Conversation( + app_id=app.id, + app_model_config_id=str(uuid.uuid4()), + model_provider="openai", + model_id="gpt-3.5-turbo", + mode="chat", + name="Test conversation", + inputs={}, + status="normal", + from_source="api", + from_end_user_id=str(uuid.uuid4()), + ) + db.session.add(conversation) + db.session.commit() + + return conversation + + def _create_message(self, app, conversation, created_at=None, with_relations=True): + """Helper to create a message with optional related records.""" + if created_at is None: + created_at = datetime.datetime.now() + + message = Message( + app_id=app.id, + conversation_id=conversation.id, + model_provider="openai", + model_id="gpt-3.5-turbo", + inputs={}, + query="Test query", + answer="Test answer", + message=[{"role": "user", "text": "Test message"}], + message_tokens=10, + message_unit_price=Decimal("0.001"), + answer_tokens=20, + answer_unit_price=Decimal("0.002"), + total_price=Decimal("0.003"), + currency="USD", + from_source="api", + from_account_id=conversation.from_end_user_id, + created_at=created_at, + ) + db.session.add(message) + db.session.flush() + + if with_relations: + self._create_message_relations(message) + + db.session.commit() + return message + + def _create_message_relations(self, message): + """Helper to create all message-related records.""" + # MessageFeedback + feedback = MessageFeedback( + app_id=message.app_id, + conversation_id=message.conversation_id, + message_id=message.id, + rating="like", + from_source="api", + from_end_user_id=str(uuid.uuid4()), + ) + db.session.add(feedback) + + # MessageAnnotation + annotation = MessageAnnotation( + app_id=message.app_id, + conversation_id=message.conversation_id, + message_id=message.id, + question="Test question", + content="Test annotation", + account_id=message.from_account_id, + ) + db.session.add(annotation) + + # MessageChain + chain = MessageChain( + message_id=message.id, + type="system", + input=json.dumps({"test": "input"}), + output=json.dumps({"test": "output"}), + ) + db.session.add(chain) + db.session.flush() + + # MessageFile + file = MessageFile( + message_id=message.id, + type="image", + transfer_method="local_file", + url="http://example.com/test.jpg", + belongs_to="user", + created_by_role="end_user", + created_by=str(uuid.uuid4()), + ) + db.session.add(file) + + # SavedMessage + saved = SavedMessage( + app_id=message.app_id, + message_id=message.id, + created_by_role="end_user", + created_by=str(uuid.uuid4()), + ) + db.session.add(saved) + + db.session.flush() + + # AppAnnotationHitHistory + hit = AppAnnotationHitHistory( + app_id=message.app_id, + annotation_id=annotation.id, + message_id=message.id, + source="annotation", + question="Test question", + account_id=message.from_account_id, + annotation_question="Test annotation question", + annotation_content="Test annotation content", + ) + db.session.add(hit) + + # DatasetRetrieverResource + resource = DatasetRetrieverResource( + message_id=message.id, + position=1, + dataset_id=str(uuid.uuid4()), + dataset_name="Test dataset", + document_id=str(uuid.uuid4()), + document_name="Test document", + data_source_type="upload_file", + segment_id=str(uuid.uuid4()), + score=0.9, + content="Test content", + hit_count=1, + word_count=10, + segment_position=1, + index_node_hash="test_hash", + retriever_from="dataset", + created_by=message.from_account_id, + ) + db.session.add(resource) + + def test_billing_disabled_deletes_all_messages_in_time_range( + self, db_session_with_containers, mock_billing_disabled + ): + """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan.""" + # Arrange - Create tenant with messages (plan doesn't matter for billing disabled) + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create messages: in-range (should be deleted) and out-of-range (should be kept) + in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0) + out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0) + + in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True) + in_range_msg_id = in_range_msg.id + + out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True) + out_of_range_msg_id = out_of_range_msg.id + + # Act - create_message_clean_policy should return BillingDisabledPolicy + policy = create_message_clean_policy() + + assert isinstance(policy, BillingDisabledPolicy) + + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime(2024, 1, 10, 0, 0, 0), + end_before=datetime.datetime(2024, 1, 20, 0, 0, 0), + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 1 # Only in-range message fetched + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # In-range message deleted + assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0 + # Out-of-range message kept + assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 + + # Related records of in-range message deleted + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0 + assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0 + # Related records of out-of-range message kept + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1 + + def test_no_messages_returns_empty_stats( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test cleaning when there are no messages to delete (B1).""" + # Arrange + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + start_from = datetime.datetime.now() - datetime.timedelta(days=60) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = {} + + # Act + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - loop runs once to check, finds nothing + assert stats["batches"] == 1 + assert stats["total_messages"] == 0 + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + def test_mixed_sandbox_and_paid_tenants( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test cleaning with mixed sandbox and paid tenants (B2).""" + # Arrange - Create sandbox tenants with expired messages + sandbox_tenants = [] + sandbox_message_ids = [] + for i in range(2): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + sandbox_tenants.append(tenant) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 3 expired messages per sandbox tenant + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + for j in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + sandbox_message_ids.append(msg.id) + + # Create paid tenants with expired messages (should NOT be deleted) + paid_tenants = [] + paid_message_ids = [] + for i in range(2): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + paid_tenants.append(tenant) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 2 expired messages per paid tenant + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + for j in range(2): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + paid_message_ids.append(msg.id) + + # Mock billing service - return plan and expiration_date + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + expired_15_days_ago = now_timestamp - (15 * 24 * 60 * 60) # Beyond 7-day grace period + + plan_map = {} + for tenant in sandbox_tenants: + plan_map[tenant.id] = { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_15_days_ago, + } + for tenant in paid_tenants: + plan_map[tenant.id] = { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=7) + + assert isinstance(policy, BillingSandboxPolicy) + + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 10 # 2 sandbox * 3 + 2 paid * 2 + assert stats["filtered_messages"] == 6 # 2 sandbox tenants * 3 messages + assert stats["total_deleted"] == 6 + + # Only sandbox messages should be deleted + assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 + # Paid messages should remain + assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 + + # Related records of sandbox messages should be deleted + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0 + assert ( + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() + == 0 + ) + + def test_cursor_pagination_multiple_batches( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test cursor pagination works correctly across multiple batches (B3).""" + # Arrange - Create sandbox tenant with messages that will span multiple batches + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 10 expired messages with different timestamps + base_date = datetime.datetime.now() - datetime.timedelta(days=35) + message_ids = [] + for i in range(10): + msg = self._create_message( + app, + conv, + created_at=base_date + datetime.timedelta(hours=i), + with_relations=False, # Skip relations for speed + ) + message_ids.append(msg.id) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act - Use small batch size to trigger multiple batches + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=3, # Small batch size to test pagination + ) + stats = service.run() + + # 5 batches for 10 messages with batch_size=3, the last batch is empty + assert stats["batches"] == 5 + assert stats["total_messages"] == 10 + assert stats["filtered_messages"] == 10 + assert stats["total_deleted"] == 10 + + # All messages should be deleted + assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 + + def test_dry_run_does_not_delete( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test dry_run mode does not delete messages (B4).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create expired messages + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + message_ids = [] + for i in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + message_ids.append(msg.id) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + dry_run=True, # Dry run mode + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 3 + assert stats["filtered_messages"] == 3 # Messages identified + assert stats["total_deleted"] == 0 # But NOT deleted + + # All messages should still exist + assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 + # Related records should also still exist + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 + + def test_partial_plan_data_safe_default( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test when billing returns partial data, unknown tenants are preserved (B5).""" + # Arrange - Create 3 tenants + tenants_data = [] + for i in range(3): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date) + + tenants_data.append( + { + "tenant": tenant, + "message_id": msg.id, + } + ) + + # Mock billing service to return partial data + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + # Only tenant[0] is confirmed as sandbox, tenant[1] is professional, tenant[2] is missing + partial_plan_map = { + tenants_data[0]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + }, + tenants_data[1]["tenant"].id: { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year + }, + # tenants_data[2] is missing from response + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = partial_plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only tenant[0]'s message should be deleted + assert stats["total_messages"] == 3 # 3 tenants * 1 message + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # Check which messages were deleted + assert ( + db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 + ) # Sandbox tenant's message deleted + + assert ( + db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + ) # Professional tenant's message preserved + + assert ( + db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 + ) # Unknown tenant's message preserved (safe default) + + def test_empty_plan_data_skips_deletion( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test when billing returns empty data, skip deletion entirely (B6).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date) + msg_id = msg.id + db.session.commit() + + # Mock billing service to return empty data (simulating failure/no data scenario) + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = {} # Empty response, tenant plan unknown + + # Act - Should not raise exception, just skip deletion + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - No messages should be deleted when plan is unknown + assert stats["total_messages"] == 1 + assert stats["filtered_messages"] == 0 # Cannot determine sandbox messages + assert stats["total_deleted"] == 0 + + # Message should still exist (safe default - don't delete if plan is unknown) + assert db.session.query(Message).where(Message.id == msg_id).count() == 1 + + def test_time_range_boundary_behavior( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test that messages are correctly filtered by [start_from, end_before) time range (B7).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create messages: before range, in range, after range + msg_before = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from + with_relations=False, + ) + msg_before_id = msg_before.id + + msg_at_start = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) + with_relations=False, + ) + msg_at_start_id = msg_at_start.id + + msg_in_range = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range + with_relations=False, + ) + msg_in_range_id = msg_in_range.id + + msg_at_end = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) + with_relations=False, + ) + msg_at_end_id = msg_at_end.id + + msg_after = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before + with_relations=False, + ) + msg_after_id = msg_after.id + + db.session.commit() + + # Mock billing service + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act - Clean with specific time range [2024-01-10, 2024-01-20) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime(2024, 1, 10, 12, 0, 0), + end_before=datetime.datetime(2024, 1, 20, 12, 0, 0), + batch_size=100, + ) + stats = service.run() + + # Assert - Only messages in [start_from, end_before) should be deleted + assert stats["total_messages"] == 2 # Only in-range messages fetched + assert stats["filtered_messages"] == 2 # msg_at_start and msg_in_range + assert stats["total_deleted"] == 2 + + # Verify specific messages using stored IDs + # Before range, kept + assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 + # At start (inclusive), deleted + assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 + # In range, deleted + assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 + # At end (exclusive), kept + assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 + # After range, kept + assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 + + def test_grace_period_scenarios( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test cleaning with different graceful period scenarios (B8).""" + # Arrange - Create 5 different tenants with different plan and expiration scenarios + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + graceful_period = 8 # Use 8 days for this test + + # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) + # Should NOT be deleted + account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app1 = self._create_app(tenant1, account1) + conv1 = self._create_conversation(app1) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1_id = msg1.id + expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period + + # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) + # Should be deleted + account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app2 = self._create_app(tenant2, account2) + conv2 = self._create_conversation(app2) + msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + msg2_id = msg2.id + expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period + + # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) + # Should be deleted + account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app3 = self._create_app(tenant3, account3) + conv3 = self._create_conversation(app3) + msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False) + msg3_id = msg3.id + + # Scenario 4: Non-sandbox plan (professional) with no expiration (future date) + # Should NOT be deleted + account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + app4 = self._create_app(tenant4, account4) + conv4 = self._create_conversation(app4) + msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) + msg4_id = msg4.id + future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year + + # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) + # Should NOT be deleted (boundary is exclusive: > graceful_period) + account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app5 = self._create_app(tenant5, account5) + conv5 = self._create_conversation(app5) + msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) + msg5_id = msg5.id + expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary + + db.session.commit() + + # Mock billing service with all scenarios + plan_map = { + tenant1.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_5_days_ago, + }, + tenant2.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_10_days_ago, + }, + tenant3.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenant4.id: { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": future_expiration, + }, + tenant5.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_exactly_8_days_ago, + }, + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy( + graceful_period_days=graceful_period, + current_timestamp=now_timestamp, # Use fixed timestamp for deterministic behavior + ) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only messages from scenario 2 and 3 should be deleted + assert stats["total_messages"] == 5 # 5 tenants * 1 message + assert stats["filtered_messages"] == 2 + assert stats["total_deleted"] == 2 + + # Verify each scenario using saved IDs + assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept + assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted + assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted + assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept + assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept + + def test_tenant_whitelist( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test that whitelisted tenants' messages are not deleted (B9).""" + # Arrange - Create 3 sandbox tenants with expired messages + tenants_data = [] + for i in range(3): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) + + tenants_data.append( + { + "tenant": tenant, + "message_id": msg.id, + } + ) + + # Mock billing service - all tenants are sandbox with no subscription + plan_map = { + tenants_data[0]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenants_data[1]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenants_data[2]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + } + + # Setup whitelist - tenant0 and tenant1 are whitelisted, tenant2 is not + whitelist = [tenants_data[0]["tenant"].id, tenants_data[1]["tenant"].id] + mock_whitelist.return_value = whitelist + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only tenant2's message should be deleted (not whitelisted) + assert stats["total_messages"] == 3 # 3 tenants * 1 message + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # Verify tenant0's message still exists (whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 + + # Verify tenant1's message still exists (whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + + # Verify tenant2's message was deleted (not whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 + + def test_from_days_cleans_old_messages( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test from_days correctly cleans messages older than N days (B11).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create old messages (should be deleted - older than 30 days) + old_date = datetime.datetime.now() - datetime.timedelta(days=45) + old_msg_ids = [] + for i in range(3): + msg = self._create_message( + app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False + ) + old_msg_ids.append(msg.id) + + # Create recent messages (should be kept - newer than 30 days) + recent_date = datetime.datetime.now() - datetime.timedelta(days=15) + recent_msg_ids = [] + for i in range(2): + msg = self._create_message( + app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False + ) + recent_msg_ids.append(msg.id) + + db.session.commit() + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + } + } + + # Act - Use from_days to clean messages older than 30 days + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_days( + policy=policy, + days=30, + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 3 # Only old messages in range + assert stats["filtered_messages"] == 3 # Only old messages + assert stats["total_deleted"] == 3 + + # Old messages deleted + assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 + # Recent messages kept + assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 + + def test_whitelist_precedence_over_grace_period( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test that whitelist takes precedence over grace period logic.""" + # Arrange - Create 2 sandbox tenants + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + # Tenant1: whitelisted, expired beyond grace period + account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app1 = self._create_app(tenant1, account1) + conv1 = self._create_conversation(app1) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace + + # Tenant2: not whitelisted, within grace period + account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app2 = self._create_app(tenant2, account2) + conv2 = self._create_conversation(app2) + msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace + + # Mock billing service + plan_map = { + tenant1.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_30_days_ago, # Beyond grace period + }, + tenant2.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_10_days_ago, # Within grace period + }, + } + + # Setup whitelist - only tenant1 is whitelisted + whitelist = [tenant1.id] + mock_whitelist.return_value = whitelist + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - No messages should be deleted + # tenant1: whitelisted (protected even though beyond grace period) + # tenant2: within grace period (not eligible for deletion) + assert stats["total_messages"] == 2 # 2 tenants * 1 message + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + # Verify both messages still exist + assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted + assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period + + def test_empty_whitelist_deletes_eligible_messages( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test that empty whitelist behaves as no whitelist (all eligible messages deleted).""" + # Arrange - Create sandbox tenant with expired messages + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg_ids = [] + for i in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg_ids.append(msg.id) + + # Mock billing service + plan_map = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + } + } + + # Setup empty whitelist (default behavior from fixture) + mock_whitelist.return_value = [] + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - All messages should be deleted (no whitelist protection) + assert stats["total_messages"] == 3 + assert stats["filtered_messages"] == 3 + assert stats["total_deleted"] == 3 + + # Verify all messages were deleted + assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py new file mode 100644 index 0000000000..3b619195c7 --- /dev/null +++ b/api/tests/unit_tests/services/test_messages_clean_service.py @@ -0,0 +1,627 @@ +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from enums.cloud_plan import CloudPlan +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, + BillingSandboxPolicy, + SimpleMessage, + create_message_clean_policy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage: + """Helper to create a SimpleMessage with a fixed created_at timestamp.""" + return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1)) + + +def make_plan_provider(tenant_plans: dict) -> MagicMock: + """Helper to create a mock plan_provider that returns the given tenant_plans.""" + provider = MagicMock() + provider.return_value = tenant_plans + return provider + + +class TestBillingSandboxPolicyFilterMessageIds: + """Unit tests for BillingSandboxPolicy.filter_message_ids method.""" + + # Fixed timestamp for deterministic tests + CURRENT_TIMESTAMP = 1000000 + GRACEFUL_PERIOD_DAYS = 8 + GRACEFUL_PERIOD_SECONDS = GRACEFUL_PERIOD_DAYS * 24 * 60 * 60 + + def test_missing_tenant_mapping_excluded(self): + """Test that messages with missing app-to-tenant mapping are excluded.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {} # No mapping + tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}} + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_missing_tenant_plan_excluded(self): + """Test that messages with missing tenant plan are excluded (safe default).""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = {} # No plans + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_non_sandbox_plan_excluded(self): + """Test that messages from non-sandbox plans (PROFESSIONAL/TEAM) are excluded.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.TEAM, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, # Only this one + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg3 (sandbox tenant) should be included + assert set(result) == {"msg3"} + + def test_whitelist_skip(self): + """Test that whitelisted tenants are excluded even if sandbox + expired.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), # Whitelisted - excluded + make_simple_message("msg2", "app2"), # Not whitelisted - included + make_simple_message("msg3", "app3"), # Whitelisted - excluded + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + plan_provider = make_plan_provider(tenant_plans) + tenant_whitelist = ["tenant1", "tenant3"] + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + tenant_whitelist=tenant_whitelist, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg2 should be included + assert set(result) == {"msg2"} + + def test_no_previous_subscription_included(self): + """Test that messages with expiration_date=-1 (no previous subscription) are included.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all messages should be included + assert set(result) == {"msg1", "msg2"} + + def test_within_grace_period_excluded(self): + """Test that messages within grace period are excluded.""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_1_day_ago = now - (1 * 24 * 60 * 60) + expired_5_days_ago = now - (5 * 24 * 60 * 60) + expired_7_days_ago = now - (7 * 24 * 60 * 60) + + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, # 8 days + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all within 8-day grace period, none should be included + assert list(result) == [] + + def test_exactly_at_boundary_excluded(self): + """Test that messages exactly at grace period boundary are excluded (code uses >).""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_exactly_8_days_ago = now - self.GRACEFUL_PERIOD_SECONDS # Exactly at boundary + + messages = [make_simple_message("msg1", "app1")] + app_to_tenant = {"app1": "tenant1"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - exactly at boundary (==) should be excluded (code uses >) + assert list(result) == [] + + def test_beyond_grace_period_included(self): + """Test that messages beyond grace period are included.""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond 8-day grace + expired_30_days_ago = now - (30 * 24 * 60 * 60) # Well beyond + + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - both beyond grace period, should be included + assert set(result) == {"msg1", "msg2"} + + def test_empty_messages_returns_empty(self): + """Test that empty messages returns empty list.""" + # Arrange + messages: list[SimpleMessage] = [] + app_to_tenant = {"app1": "tenant1"} + plan_provider = make_plan_provider({"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_plan_provider_called_with_correct_tenant_ids(self): + """Test that plan_provider is called with correct tenant_ids.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant1"} # tenant1 appears twice + plan_provider = make_plan_provider({}) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + policy.filter_message_ids(messages, app_to_tenant) + + # Assert - plan_provider should be called once with unique tenant_ids + plan_provider.assert_called_once() + called_tenant_ids = set(plan_provider.call_args[0][0]) + assert called_tenant_ids == {"tenant1", "tenant2"} + + def test_complex_mixed_scenario(self): + """Test complex scenario with mixed plans, expirations, whitelist, and missing mappings.""" + # Arrange + now = self.CURRENT_TIMESTAMP + sandbox_expired_old = now - (15 * 24 * 60 * 60) # Beyond grace + sandbox_expired_recent = now - (3 * 24 * 60 * 60) # Within grace + future_expiration = now + (30 * 24 * 60 * 60) + + messages = [ + make_simple_message("msg1", "app1"), # Sandbox, no subscription - included + make_simple_message("msg2", "app2"), # Sandbox, expired old - included + make_simple_message("msg3", "app3"), # Sandbox, within grace - excluded + make_simple_message("msg4", "app4"), # Team plan, active - excluded + make_simple_message("msg5", "app5"), # No tenant mapping - excluded + make_simple_message("msg6", "app6"), # No plan info - excluded + make_simple_message("msg7", "app7"), # Sandbox, expired old, whitelisted - excluded + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + "app4": "tenant4", + "app6": "tenant6", # Has mapping but no plan + "app7": "tenant7", + # app5 has no mapping + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent}, + "tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration}, + "tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + # tenant6 has no plan + } + plan_provider = make_plan_provider(tenant_plans) + tenant_whitelist = ["tenant7"] + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + tenant_whitelist=tenant_whitelist, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg1 and msg2 should be included + assert set(result) == {"msg1", "msg2"} + + +class TestBillingDisabledPolicyFilterMessageIds: + """Unit tests for BillingDisabledPolicy.filter_message_ids method.""" + + def test_returns_all_message_ids(self): + """Test that all message IDs are returned (order-preserving).""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all message IDs returned in order + assert list(result) == ["msg1", "msg2", "msg3"] + + def test_ignores_app_to_tenant(self): + """Test that app_to_tenant mapping is ignored.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant: dict[str, str] = {} # Empty - should be ignored + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all message IDs still returned + assert list(result) == ["msg1", "msg2"] + + def test_empty_messages_returns_empty(self): + """Test that empty messages returns empty list.""" + # Arrange + messages: list[SimpleMessage] = [] + app_to_tenant = {"app1": "tenant1"} + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + +class TestCreateMessageCleanPolicy: + """Unit tests for create_message_clean_policy factory function.""" + + @patch("services.retention.conversation.messages_clean_policy.dify_config") + def test_billing_disabled_returns_billing_disabled_policy(self, mock_config): + """Test that BILLING_ENABLED=False returns BillingDisabledPolicy.""" + # Arrange + mock_config.BILLING_ENABLED = False + + # Act + policy = create_message_clean_policy(graceful_period_days=21) + + # Assert + assert isinstance(policy, BillingDisabledPolicy) + + @patch("services.retention.conversation.messages_clean_policy.BillingService") + @patch("services.retention.conversation.messages_clean_policy.dify_config") + def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service): + """Test that BillingSandboxPolicy is created with correct internal values.""" + # Arrange + mock_config.BILLING_ENABLED = True + whitelist = ["tenant1", "tenant2"] + mock_billing_service.get_expired_subscription_cleanup_whitelist.return_value = whitelist + mock_plan_provider = MagicMock() + mock_billing_service.get_plan_bulk_with_cache = mock_plan_provider + + # Act + policy = create_message_clean_policy(graceful_period_days=14, current_timestamp=1234567) + + # Assert + mock_billing_service.get_expired_subscription_cleanup_whitelist.assert_called_once() + assert isinstance(policy, BillingSandboxPolicy) + assert policy._graceful_period_days == 14 + assert list(policy._tenant_whitelist) == whitelist + assert policy._plan_provider == mock_plan_provider + assert policy._current_timestamp == 1234567 + + +class TestMessagesCleanServiceFromTimeRange: + """Unit tests for MessagesCleanService.from_time_range factory method.""" + + def test_start_from_end_before_raises_value_error(self): + """Test that start_from == end_before raises ValueError.""" + policy = BillingDisabledPolicy() + + # Arrange + same_time = datetime.datetime(2024, 1, 1, 12, 0, 0) + + # Act & Assert + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=same_time, + end_before=same_time, + ) + + # Arrange + start_from = datetime.datetime(2024, 12, 31) + end_before = datetime.datetime(2024, 1, 1) + + # Act & Assert + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + ) + + def test_batch_size_raises_value_error(self): + """Test that batch_size=0 raises ValueError.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=0, + ) + + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=-100, + ) + + def test_valid_params_creates_instance(self): + """Test that valid parameters create a correctly configured instance.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 12, 31, 23, 59, 59) + policy = BillingDisabledPolicy() + batch_size = 500 + dry_run = True + + # Act + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + + # Assert + assert isinstance(service, MessagesCleanService) + assert service._policy is policy + assert service._start_from == start_from + assert service._end_before == end_before + assert service._batch_size == batch_size + assert service._dry_run == dry_run + + def test_default_params(self): + """Test that default parameters are applied correctly.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + ) + + # Assert + assert service._batch_size == 1000 # default + assert service._dry_run is False # default + + +class TestMessagesCleanServiceFromDays: + """Unit tests for MessagesCleanService.from_days factory method.""" + + def test_days_raises_value_error(self): + """Test that days < 0 raises ValueError.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + MessagesCleanService.from_days(policy=policy, days=-1) + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days(policy=policy, days=0) + + # Assert + assert service._end_before == fixed_now + + def test_batch_size_raises_value_error(self): + """Test that batch_size=0 raises ValueError.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy=policy, days=30, batch_size=0) + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy=policy, days=30, batch_size=-500) + + def test_valid_params_creates_instance(self): + """Test that valid parameters create a correctly configured instance.""" + # Arrange + policy = BillingDisabledPolicy() + days = 90 + batch_size = 500 + dry_run = True + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days( + policy=policy, + days=days, + batch_size=batch_size, + dry_run=dry_run, + ) + + # Assert + expected_end_before = fixed_now - datetime.timedelta(days=days) + assert isinstance(service, MessagesCleanService) + assert service._policy is policy + assert service._start_from is None + assert service._end_before == expected_end_before + assert service._batch_size == batch_size + assert service._dry_run == dry_run + + def test_default_params(self): + """Test that default parameters are applied correctly.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days(policy=policy) + + # Assert + expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30 + assert service._end_before == expected_end_before + assert service._batch_size == 1000 # default + assert service._dry_run is False # default