diff --git a/api/commands.py b/api/commands.py
index a8d89ac200..cfc9ba24dd 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,7 +1,9 @@
 import base64
+import datetime
 import json
 import logging
 import secrets
+import time
 from typing import Any
 
 import click
@@ -45,6 +47,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
 from services.plugin.data_migration import PluginDataMigration
 from services.plugin.plugin_migration import PluginMigration
 from services.plugin.plugin_service import PluginService
+from services.sandbox_messages_clean_service import SandboxMessagesCleanService
 from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
 
 logger = logging.getLogger(__name__)
@@ -1900,3 +1903,76 @@ def migrate_oss(
     except Exception as e:
         db.session.rollback()
         click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
+
+
+@click.command("clean-expired-sandbox-messages", help="Clean expired sandbox messages.")
+@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
+@click.option(
+    "--graceful-period",
+    default=21,
+    show_default=True,
+    help="Graceful period in days after subscription expiration.",
+)
+@click.option(
+    "--start-from",
+    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+    default=None,
+    help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+    "--end-before",
+    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+    default=None,
+    help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
+)
+@click.option("--dry-run", is_flag=True, default=False, help="Show the messages that would be cleaned without deleting them.")
+def clean_expired_sandbox_messages(
+    batch_size: int,
+    graceful_period: int,
+    start_from: datetime.datetime | None,
+    end_before: datetime.datetime | None,
+    dry_run: bool,
+):
+    """
+    Clean expired messages and related data for sandbox tenants.
+    """
+    if not dify_config.BILLING_ENABLED:
+        click.echo(click.style("Billing is not enabled. Skipping sandbox messages cleanup.", fg="yellow"))
+        return
+
+    click.echo(click.style("clean_messages: start clean messages.", fg="green"))
+
+    start_at = time.perf_counter()
+
+    try:
+        stats = SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
+            start_from=start_from,
+            end_before=end_before,
+            graceful_period=graceful_period,
+            batch_size=batch_size,
+            dry_run=dry_run,
+        )
+
+        end_at = time.perf_counter()
+        click.echo(
+            click.style(
+                f"clean_messages: completed successfully\n"
+                f"  - Latency: {end_at - start_at:.2f}s\n"
+                f"  - Batches processed: {stats['batches']}\n"
+                f"  - Messages found: {stats['total_messages']}\n"
+                f"  - Messages deleted: {stats['total_deleted']}",
+                fg="green",
+            )
+        )
+    except Exception as e:
+        end_at = time.perf_counter()
+        logger.exception("clean_messages failed")
+        click.echo(
+            click.style(
+                f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
+                fg="red",
+            )
+        )
+        raise
+
+    click.echo(click.style("Sandbox messages cleanup completed.", fg="green"))
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index 71a63168a5..0439272579 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -4,6 +4,7 @@ from dify_app import DifyApp
 def init_app(app: DifyApp):
     from commands import (
         add_qdrant_index,
+        clean_expired_sandbox_messages,
         cleanup_orphaned_draft_variables,
         clear_free_plan_tenant_expired_logs,
         clear_orphaned_file_records,
@@ -54,6 +55,7 @@ def init_app(app: DifyApp):
         setup_datasource_oauth_client,
         transform_datasource_credentials,
         install_rag_pipeline_plugins,
+        clean_expired_sandbox_messages,
     ]
     for cmd in cmds_to_register:
         app.cli.add_command(cmd)
diff --git a/api/migrations/versions/2025_12_18_1639-649d817a739e_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2025_12_18_1639-649d817a739e_feat_add_created_at_id_index_to_messages.py
new file mode 100644
index 0000000000..842441a35e
--- /dev/null
+++ b/api/migrations/versions/2025_12_18_1639-649d817a739e_feat_add_created_at_id_index_to_messages.py
@@ -0,0 +1,33 @@
+"""feat: add created_at id index to messages
+
+Revision ID: 649d817a739e
+Revises: 03ea244985ce
+Create Date: 2025-12-18 16:39:33.090454
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '649d817a739e'
+down_revision = '03ea244985ce'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('messages', schema=None) as batch_op:
+        batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_created_at_id_idx') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 88cb945b3f..8cf4912477 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -965,6 +965,7 @@ class Message(Base): Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), Index("message_created_at_idx", "created_at"), Index("message_app_mode_idx", "app_mode"), + Index("message_created_at_id_idx", "created_at", "id"), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 352a84b592..7b5fd76ee0 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -1,90 +1,54 @@ -import datetime import logging import time import click -from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config -from enums.cloud_plan import CloudPlan -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from models.model import ( - App, - Message, - MessageAgentThought, - MessageAnnotation, - MessageChain, - MessageFeedback, - MessageFile, -) -from models.web import SavedMessage -from services.feature_service import FeatureService +from services.sandbox_messages_clean_service import SandboxMessagesCleanService logger = logging.getLogger(__name__) -@app.celery.task(queue="dataset") +@app.celery.task(queue="retention") def clean_messages(): - click.echo(click.style("Start clean messages.", fg="green")) - start_at = time.perf_counter() - plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( - days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING - ) - while True: - try: - # Main query with join and filter - messages = ( - db.session.query(Message) - .where(Message.created_at < plan_sandbox_clean_message_day) - .order_by(Message.created_at.desc()) - .limit(100) - .all() - ) + """ + Clean expired messages from sandbox plan tenants. 
- except SQLAlchemyError: - raise - if not messages: - break - for message in messages: - app = db.session.query(App).filter_by(id=message.app_id).first() - if not app: - logger.warning( - "Expected App record to exist, but none was found, app_id=%s, message_id=%s", - message.app_id, - message.id, - ) - continue - features_cache_key = f"features:{app.tenant_id}" - plan_cache = redis_client.get(features_cache_key) - if plan_cache is None: - features = FeatureService.get_features(app.tenant_id) - redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) - plan = features.billing.subscription.plan - else: - plan = plan_cache.decode() - if plan == CloudPlan.SANDBOX: - # clean related message - db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(Message).where(Message.id == message.id).delete() - db.session.commit() - end_at = time.perf_counter() - click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green")) + This task uses SandboxMessagesCleanService to efficiently clean messages in batches. + """ + if not dify_config.BILLING_ENABLED: + click.echo(click.style("Billing is not enabled. 
Skipping sandbox messages cleanup.", fg="yellow")) + return + + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + start_at = time.perf_counter() + + try: + stats = SandboxMessagesCleanService.clean_sandbox_messages_by_days( + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + graceful_period=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + ) + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Messages found: {stats['total_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise diff --git a/api/services/sandbox_messages_clean_service.py b/api/services/sandbox_messages_clean_service.py new file mode 100644 index 0000000000..5d457c88ff --- /dev/null +++ b/api/services/sandbox_messages_clean_service.py @@ -0,0 +1,488 @@ +import datetime +import json +import logging +from collections.abc import Sequence +from dataclasses import dataclass +from typing import cast + +from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import Session + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.model import ( + App, + AppAnnotationHitHistory, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.billing_service import BillingService, SubscriptionPlan + +logger = logging.getLogger(__name__) + + +@dataclass +class SimpleMessage: + """Lightweight message info containing only essential fields for cleaning.""" + + id: str + app_id: str + created_at: datetime.datetime + + +class SandboxMessagesCleanService: + """ + Service for cleaning expired messages from sandbox plan tenants. + """ + + # Redis key prefix for tenant plan cache + PLAN_CACHE_KEY_PREFIX = "tenant_plan:" + # Cache TTL: 10 minutes + PLAN_CACHE_TTL = 600 + + @classmethod + def clean_sandbox_messages_by_time_range( + cls, + start_from: datetime.datetime, + end_before: datetime.datetime, + graceful_period: int = 21, + batch_size: int = 1000, + dry_run: bool = False, + ) -> dict[str, int]: + """ + Clean sandbox messages within a specific time range [start_from, end_before). 
+ + Args: + start_from: Start time (inclusive) of the range + end_before: End time (exclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + Statistics about the cleaning operation + + Raises: + ValueError: If start_from >= end_before + """ + if start_from >= end_before: + raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + if graceful_period < 0: + raise ValueError(f"graceful_period ({graceful_period}) must be greater than or equal to 0") + + logger.info("clean_messages: start_from=%s, end_before=%s, batch_size=%s", start_from, end_before, batch_size) + + return cls._clean_sandbox_messages_by_time_range( + start_from=start_from, + end_before=end_before, + graceful_period=graceful_period, + batch_size=batch_size, + dry_run=dry_run, + ) + + @classmethod + def clean_sandbox_messages_by_days( + cls, + days: int = 30, + graceful_period: int = 21, + batch_size: int = 1000, + dry_run: bool = False, + ) -> dict[str, int]: + """ + Clean sandbox messages older than specified days. + + Args: + days: Number of days to look back from now + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + Statistics about the cleaning operation + """ + if days < 0: + raise ValueError(f"days ({days}) must be greater than or equal to 0") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + if graceful_period < 0: + raise ValueError(f"graceful_period ({graceful_period}) must be greater than or equal to 0") + + end_before = datetime.datetime.now() - datetime.timedelta(days=days) + + logger.info("clean_messages: days=%s, end_before=%s, batch_size=%s", days, end_before, batch_size) + + return cls._clean_sandbox_messages_by_time_range( + end_before=end_before, + start_from=None, + graceful_period=graceful_period, + batch_size=batch_size, + dry_run=dry_run, + ) + + @classmethod + def _clean_sandbox_messages_by_time_range( + cls, + end_before: datetime.datetime, + start_from: datetime.datetime | None = None, + graceful_period: int = 21, + batch_size: int = 1000, + dry_run: bool = False, + ) -> dict[str, int]: + """ + Internal method to clean sandbox messages within a time range using cursor-based pagination. + Time range is [start_from, end_before) - left-closed, right-open interval. + + Steps: + 1. Iterate messages using cursor pagination (by created_at, id) + 2. Extract app_ids from messages + 3. Query tenant_ids from apps + 4. Batch fetch subscription plans + 5. 
Delete messages from sandbox tenants + + Args: + end_before: End time (exclusive) of the range + start_from: Optional start time (inclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + Dict with statistics: batches, total_messages, total_deleted + """ + stats = { + "batches": 0, + "total_messages": 0, + "total_deleted": 0, + } + + if not dify_config.BILLING_ENABLED: + logger.info("clean_messages: billing is not enabled, skip cleaning messages") + return stats + + tenant_whitelist = cls._get_tenant_whitelist() + logger.info("clean_messages: tenant_whitelist=%s", tenant_whitelist) + + # Cursor-based pagination using (created_at, id) to avoid infinite loops + # and ensure proper ordering with time-based filtering + _cursor: tuple[datetime.datetime, str] | None = None + + logger.info( + "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s", + dry_run, + start_from, + end_before, + ) + + while True: + stats["batches"] += 1 + + # Step 1: Fetch a batch of messages using cursor + with Session(db.engine, expire_on_commit=False) as session: + msg_stmt = ( + select(Message.id, Message.app_id, Message.created_at) + .where(Message.created_at < end_before) + .order_by(Message.created_at, Message.id) + .limit(batch_size) + ) + + if start_from: + msg_stmt = msg_stmt.where(Message.created_at >= start_from) + + # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id) + # This translates to: + # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id) + if _cursor: + # Continuing from previous batch + msg_stmt = msg_stmt.where( + (Message.created_at > _cursor[0]) + | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1])) + ) + + raw_messages = list(session.execute(msg_stmt).all()) + messages = [ + SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at) + for msg_id, app_id, msg_created_at in raw_messages + ] + + if not messages: + logger.info("clean_messages (batch %s): no more messages to process", stats["batches"]) + break + + # Update cursor to the last message's (created_at, id) + _cursor = (messages[-1].created_at, messages[-1].id) + + # Step 2: Extract app_ids from this batch + app_ids = list({msg.app_id for msg in messages}) + + if not app_ids: + logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"]) + continue + + # Step 3: Query tenant_ids from apps + app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids)) + apps = list(session.execute(app_stmt).all()) + + if not apps: + logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"]) + continue + + # Step 4: End sesion to call billing API to avoid long-running transaction. 
+ # Build app_id -> tenant_id mapping + app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps} + tenant_ids = list(set(app_to_tenant.values())) + + # Batch fetch subscription plans + tenant_plans = cls._batch_fetch_tenant_plans(tenant_ids) + + # Step 5: Filter messages from sandbox tenants + sandbox_message_ids = cls._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=tenant_whitelist, + graceful_period_days=graceful_period, + ) + + if not sandbox_message_ids: + logger.info("clean_messages (batch %s): no sandbox messages found, skip", stats["batches"]) + continue + + stats["total_messages"] += len(sandbox_message_ids) + + # Step 6: Batch delete messages and their relations + if not dry_run: + with Session(db.engine, expire_on_commit=False) as session: + # Delete related records first + cls._batch_delete_message_relations(session, sandbox_message_ids) + + # Delete messages + delete_stmt = delete(Message).where(Message.id.in_(sandbox_message_ids)) + delete_result = cast(CursorResult, session.execute(delete_stmt)) + messages_deleted = delete_result.rowcount + session.commit() + + stats["total_deleted"] += messages_deleted + + logger.info( + "clean_messages (batch %s): processed %s messages, deleted %s sandbox messages", + stats["batches"], + len(messages), + messages_deleted, + ) + else: + sample_ids = ", ".join(sample_id for sample_id in sandbox_message_ids[:5]) + logger.info( + "clean_messages (batch %s, dry_run): would delete %s sandbox messages, sample ids: %s", + stats["batches"], + len(sandbox_message_ids), + sample_ids, + ) + + logger.info( + "clean_messages completed: total batches: %s, total messages: %s, total deleted: %s", + stats["batches"], + stats["total_messages"], + stats["total_deleted"], + ) + + return stats + + @classmethod + def _filter_expired_sandbox_messages( + cls, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + tenant_plans: dict[str, SubscriptionPlan], + tenant_whitelist: Sequence[str], + graceful_period_days: int, + current_timestamp: int | None = None, + ) -> list[str]: + """ + Filter messages that should be deleted based on sandbox plan expiration. + + A message should be deleted if: + 1. It belongs to a sandbox tenant AND + 2. 
Either: + a) The tenant has no previous subscription (expiration_date == -1), OR + b) The subscription expired more than graceful_period_days ago + + Args: + messages: List of message objects with id and app_id attributes + app_to_tenant: Mapping from app_id to tenant_id + tenant_plans: Mapping from tenant_id to subscription plan info + graceful_period_days: Grace period in days after expiration + current_timestamp: Current Unix timestamp (defaults to now, injectable for testing) + + Returns: + List of message IDs that should be deleted + """ + if current_timestamp is None: + current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + sandbox_message_ids: list[str] = [] + graceful_period_seconds = graceful_period_days * 24 * 60 * 60 + + for msg in messages: + # Get tenant_id for this message's app + tenant_id = app_to_tenant.get(msg.app_id) + if not tenant_id: + continue + + # Skip tenant messages in whitelist + if tenant_id in tenant_whitelist: + continue + + # Get subscription plan for this tenant + tenant_plan = tenant_plans.get(tenant_id) + if not tenant_plan: + continue + + plan = str(tenant_plan["plan"]) + expiration_date = int(tenant_plan["expiration_date"]) + + # Only process sandbox plans + if plan != CloudPlan.SANDBOX: + continue + + # Case 1: No previous subscription (-1 means never had a paid subscription) + if expiration_date == -1: + sandbox_message_ids.append(msg.id) + continue + + # Case 2: Subscription expired beyond grace period + if current_timestamp - expiration_date > graceful_period_seconds: + sandbox_message_ids.append(msg.id) + + return sandbox_message_ids + + @classmethod + def _get_tenant_whitelist(cls) -> Sequence[str]: + return BillingService.get_expired_subscription_cleanup_whitelist() + + @classmethod + def _batch_fetch_tenant_plans(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]: + """ + Batch fetch tenant plans with Redis caching. + + This method uses a two-tier strategy: + 1. First, batch fetch from Redis cache using mget + 2. For cache misses, fetch from billing API + 3. 
Update Redis cache using pipeline for new entries + + Args: + tenant_ids: List of tenant IDs + + Returns: + Dict mapping tenant_id to SubscriptionPlan (with "plan" and "expiration_date" keys) + """ + if not tenant_ids: + return {} + + tenant_plans: dict[str, SubscriptionPlan] = {} + + # Step 1: Batch fetch from Redis cache using mget + redis_keys = [f"{cls.PLAN_CACHE_KEY_PREFIX}{tenant_id}" for tenant_id in tenant_ids] + try: + cached_values = redis_client.mget(redis_keys) + + # Map cached values back to tenant_ids + cache_hits: dict[str, SubscriptionPlan] = {} + cache_misses: list[str] = [] + + for tenant_id, cached_value in zip(tenant_ids, cached_values): + if cached_value: + # Redis returns bytes, decode to string and parse JSON + json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value + try: + plan_dict = json.loads(json_str) + if isinstance(plan_dict, dict) and "plan" in plan_dict: + cache_hits[tenant_id] = cast(SubscriptionPlan, plan_dict) + tenant_plans[tenant_id] = cast(SubscriptionPlan, plan_dict) + else: + cache_misses.append(tenant_id) + except json.JSONDecodeError: + cache_misses.append(tenant_id) + else: + cache_misses.append(tenant_id) + + logger.info( + "clean_messages: fetch_tenant_plans(cache hits=%s, cache misses=%s)", + len(cache_hits), + len(cache_misses), + ) + except Exception as e: + logger.warning("clean_messages: fetch_tenant_plans(redis mget failed: %s, falling back to API)", e) + cache_misses = list(tenant_ids) + + # Step 2: Fetch missing plans from billing API + if cache_misses: + bulk_plans = BillingService.get_plan_bulk(cache_misses) + + if bulk_plans: + plans_to_cache: dict[str, SubscriptionPlan] = {} + + for tenant_id, plan_dict in bulk_plans.items(): + if isinstance(plan_dict, dict): + tenant_plans[tenant_id] = plan_dict # type: ignore + plans_to_cache[tenant_id] = plan_dict # type: ignore + + # Step 3: Batch update Redis cache using pipeline + if plans_to_cache: + try: + pipe = redis_client.pipeline() + for tenant_id, plan_dict in plans_to_cache.items(): + redis_key = f"{cls.PLAN_CACHE_KEY_PREFIX}{tenant_id}" + # Serialize dict to JSON string + json_str = json.dumps(plan_dict) + pipe.setex(redis_key, cls.PLAN_CACHE_TTL, json_str) + pipe.execute() + + logger.info( + "clean_messages: cached %s new tenant plans to Redis", + len(plans_to_cache), + ) + except Exception as e: + logger.warning("clean_messages: Redis pipeline failed: %s", e) + + return tenant_plans + + @classmethod + def _batch_delete_message_relations(cls, session: Session, message_ids: Sequence[str]) -> None: + """ + Batch delete all related records for given message IDs. 
+ + Args: + session: Database session + message_ids: List of message IDs to delete relations for + """ + if not message_ids: + return + + # Delete all related records in batch + session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids))) + + session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids))) + + session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids))) + + session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids))) + + session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids))) + + session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids))) + + session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids))) + + session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids))) diff --git a/api/tests/test_containers_integration_tests/services/test_sandbox_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_sandbox_messages_clean_service.py new file mode 100644 index 0000000000..1300046a58 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_sandbox_messages_clean_service.py @@ -0,0 +1,996 @@ +""" +Integration tests for SandboxMessagesCleanService using testcontainers. + +This module provides comprehensive integration tests for the sandbox message cleanup service +using TestContainers infrastructure with real PostgreSQL and Redis. +""" + +import datetime +import json +import uuid +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.sandbox_messages_clean_service import SandboxMessagesCleanService + + +class TestSandboxMessagesCleanServiceIntegration: + """Integration tests for SandboxMessagesCleanService._clean_sandbox_messages_by_time_range.""" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database before and after each test to ensure isolation.""" + yield + # Clear all test data in correct order (respecting foreign key constraints) + db.session.query(DatasetRetrieverResource).delete() + db.session.query(AppAnnotationHitHistory).delete() + db.session.query(SavedMessage).delete() + db.session.query(MessageFile).delete() + db.session.query(MessageAgentThought).delete() + db.session.query(MessageChain).delete() + db.session.query(MessageAnnotation).delete() + db.session.query(MessageFeedback).delete() + db.session.query(Message).delete() + db.session.query(Conversation).delete() + db.session.query(App).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + @pytest.fixture(autouse=True) + def cleanup_redis(self): + """Clean up Redis cache before each test.""" + # Clear tenant plan cache + try: + keys = 
redis_client.keys(f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass # Redis might not be available in some test environments + yield + # Clean up after test + try: + keys = redis_client.keys(f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass + + @pytest.fixture(autouse=True) + def mock_whitelist(self): + """Mock whitelist to return empty list by default.""" + with patch( + "services.sandbox_messages_clean_service.BillingService.get_expired_subscription_cleanup_whitelist" + ) as mock: + mock.return_value = [] + yield mock + + @pytest.fixture(autouse=True) + def mock_billing_enabled(self): + """Mock BILLING_ENABLED to be True for all tests.""" + with patch("services.sandbox_messages_clean_service.dify_config.BILLING_ENABLED", True): + yield + + def _create_account_and_tenant(self, plan="sandbox"): + """Helper to create account and tenant.""" + fake = Faker() + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.flush() + + tenant = Tenant( + name=fake.company(), + plan=plan, + status="normal", + ) + db.session.add(tenant) + db.session.flush() + + tenant_account_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + ) + db.session.add(tenant_account_join) + db.session.commit() + + return account, tenant + + def _create_app(self, tenant, account): + """Helper to create an app.""" + fake = Faker() + + app = App( + tenant_id=tenant.id, + name=fake.company(), + description="Test app", + mode="chat", + enable_site=True, + enable_api=True, + api_rpm=60, + api_rph=3600, + is_demo=False, + is_public=False, + created_by=account.id, + updated_by=account.id, + ) + db.session.add(app) + db.session.commit() + + return app + + def _create_conversation(self, app): + """Helper to create a conversation.""" + conversation = Conversation( + app_id=app.id, + app_model_config_id=str(uuid.uuid4()), + model_provider="openai", + model_id="gpt-3.5-turbo", + mode="chat", + name="Test conversation", + inputs={}, + status="normal", + from_source="api", + from_end_user_id=str(uuid.uuid4()), + ) + db.session.add(conversation) + db.session.commit() + + return conversation + + def _create_message(self, app, conversation, created_at=None, with_relations=True): + """Helper to create a message with optional related records.""" + if created_at is None: + created_at = datetime.datetime.now() + + message = Message( + app_id=app.id, + conversation_id=conversation.id, + model_provider="openai", + model_id="gpt-3.5-turbo", + inputs={}, + query="Test query", + answer="Test answer", + message=[{"role": "user", "text": "Test message"}], + message_tokens=10, + message_unit_price=Decimal("0.001"), + answer_tokens=20, + answer_unit_price=Decimal("0.002"), + total_price=Decimal("0.003"), + currency="USD", + from_source="api", + from_account_id=conversation.from_end_user_id, + created_at=created_at, + ) + db.session.add(message) + db.session.flush() + + if with_relations: + self._create_message_relations(message) + + db.session.commit() + return message + + def _create_message_relations(self, message): + """Helper to create all message-related records.""" + # MessageFeedback + feedback = MessageFeedback( + app_id=message.app_id, + conversation_id=message.conversation_id, + message_id=message.id, + rating="like", + from_source="api", + 
from_end_user_id=str(uuid.uuid4()), + ) + db.session.add(feedback) + + # MessageAnnotation + annotation = MessageAnnotation( + app_id=message.app_id, + conversation_id=message.conversation_id, + message_id=message.id, + question="Test question", + content="Test annotation", + account_id=message.from_account_id, + ) + db.session.add(annotation) + + # MessageChain + chain = MessageChain( + message_id=message.id, + type="system", + input=json.dumps({"test": "input"}), + output=json.dumps({"test": "output"}), + ) + db.session.add(chain) + db.session.flush() + + # MessageFile + file = MessageFile( + message_id=message.id, + type="image", + transfer_method="local_file", + url="http://example.com/test.jpg", + belongs_to="user", + created_by_role="end_user", + created_by=str(uuid.uuid4()), + ) + db.session.add(file) + + # SavedMessage + saved = SavedMessage( + app_id=message.app_id, + message_id=message.id, + created_by_role="end_user", + created_by=str(uuid.uuid4()), + ) + db.session.add(saved) + + db.session.flush() + + # AppAnnotationHitHistory + hit = AppAnnotationHitHistory( + app_id=message.app_id, + annotation_id=annotation.id, + message_id=message.id, + source="annotation", + question="Test question", + account_id=message.from_account_id, + annotation_question="Test annotation question", + annotation_content="Test annotation content", + ) + db.session.add(hit) + + # DatasetRetrieverResource + resource = DatasetRetrieverResource( + message_id=message.id, + position=1, + dataset_id=str(uuid.uuid4()), + dataset_name="Test dataset", + document_id=str(uuid.uuid4()), + document_name="Test document", + data_source_type="upload_file", + segment_id=str(uuid.uuid4()), + score=0.9, + content="Test content", + hit_count=1, + word_count=10, + segment_position=1, + index_node_hash="test_hash", + retriever_from="dataset", + created_by=message.from_account_id, + ) + db.session.add(resource) + + def test_clean_no_messages_to_delete(self, db_session_with_containers): + """Test cleaning when there are no messages to delete.""" + # Arrange + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = {} + + # Act + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=100, + ) + + # Assert + # Even with no messages, the loop runs once to check + assert stats["batches"] == 1 + assert stats["total_messages"] == 0 + assert stats["total_deleted"] == 0 + + def test_clean_mixed_sandbox_and_paid_tenants(self, db_session_with_containers): + """Test cleaning with mixed sandbox and paid tenants, correctly filtering sandbox messages.""" + # Arrange - Create sandbox tenants with expired messages + sandbox_tenants = [] + sandbox_message_ids = [] + for i in range(2): + account, tenant = self._create_account_and_tenant(plan="sandbox") + sandbox_tenants.append(tenant) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 3 expired messages per sandbox tenant + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + for j in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + sandbox_message_ids.append(msg.id) + + # Create paid tenants with expired messages (should NOT be deleted) + paid_tenants = [] + paid_message_ids = [] + for i in range(2): + account, tenant = 
self._create_account_and_tenant(plan="professional") + paid_tenants.append(tenant) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 2 expired messages per paid tenant + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + for j in range(2): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + paid_message_ids.append(msg.id) + + # Mock billing service - return plan and expiration_date + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + expired_15_days_ago = now_timestamp - (15 * 24 * 60 * 60) # Beyond 7-day grace period + + plan_map = {} + for tenant in sandbox_tenants: + plan_map[tenant.id] = { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_15_days_ago, + } + for tenant in paid_tenants: + plan_map[tenant.id] = { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year + } + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=7, + batch_size=100, + ) + + # Assert + assert stats["total_messages"] == 6 # 2 sandbox tenants * 3 messages + assert stats["total_deleted"] == 6 + + # Only sandbox messages should be deleted + assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 + # Paid messages should remain + assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 + + # Related records of sandbox messages should be deleted + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0 + assert ( + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() + == 0 + ) + + def test_clean_with_cursor_pagination(self, db_session_with_containers): + """Test cursor pagination works correctly across multiple batches.""" + # Arrange - Create sandbox tenant with messages that will span multiple batches + account, tenant = self._create_account_and_tenant(plan="sandbox") + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 10 expired messages with different timestamps + base_date = datetime.datetime.now() - datetime.timedelta(days=35) + message_ids = [] + for i in range(10): + msg = self._create_message( + app, + conv, + created_at=base_date + datetime.timedelta(hours=i), + with_relations=False, # Skip relations for speed + ) + message_ids.append(msg.id) + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act - Use small batch size to trigger multiple batches + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=3, # Small batch size to test pagination + ) + + # 5 batches for 10 messages with batch_size=3, the last batch is empty + assert stats["batches"] == 5 + assert stats["total_messages"] == 10 + assert stats["total_deleted"] == 10 + + # All messages should be 
deleted + assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 + + def test_clean_with_dry_run(self, db_session_with_containers): + """Test dry_run mode does not delete messages.""" + # Arrange + account, tenant = self._create_account_and_tenant(plan="sandbox") + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create expired messages + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + message_ids = [] + for i in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + message_ids.append(msg.id) + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=100, + dry_run=True, # Dry run mode + ) + + # Assert + assert stats["total_messages"] == 3 # Messages identified + assert stats["total_deleted"] == 0 # But NOT deleted + + # All messages should still exist + assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 + # Related records should also still exist + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 + + def test_clean_with_billing_partial_exception_some_known_plans(self, db_session_with_containers): + """Test when billing service fails but returns partial data, only delete known sandbox messages.""" + # Arrange - Create 3 tenants + tenants_data = [] + for i in range(3): + account, tenant = self._create_account_and_tenant(plan="sandbox") + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date) + + tenants_data.append( + { + "tenant": tenant, + "message_id": msg.id, + } + ) + + # Mock billing service to return partial data with new structure + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + # Only tenant[0] is confirmed as sandbox, tenant[1] is professional, tenant[2] is missing + partial_plan_map = { + tenants_data[0]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + }, + tenants_data[1]["tenant"].id: { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year + }, + # tenants_data[2] is missing from response + } + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = partial_plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=100, + ) + + # Assert - Only tenant[0]'s message should be deleted + assert stats["total_messages"] == 1 + assert stats["total_deleted"] == 1 + + # Check which messages were deleted + assert ( + db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 + ) # Sandbox tenant's message deleted + + assert ( + 
db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + ) # Professional tenant's message preserved + + assert ( + db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 + ) # Unknown tenant's message preserved (safe default) + + def test_clean_with_billing_exception_no_data(self, db_session_with_containers): + """Test when billing service returns empty data, skip deletion for that batch.""" + # Arrange + account, tenant = self._create_account_and_tenant(plan="sandbox") + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg_id = None + msg = self._create_message(app, conv, created_at=expired_date) + msg_id = msg.id # Store ID before any operations + db.session.commit() + + # Mock billing service to return empty data (simulating failure/no data scenario) + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = {} # Empty response, tenant plan unknown + + # Act - Should not raise exception, just skip deletion + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=100, + ) + + # Assert - No messages should be deleted when plan is unknown + assert stats["total_messages"] == 0 # Cannot determine sandbox messages + assert stats["total_deleted"] == 0 + + # Message should still exist (safe default - don't delete if plan is unknown) + assert db.session.query(Message).where(Message.id == msg_id).count() == 1 + + def test_redis_cache_for_tenant_plans(self, db_session_with_containers): + """Test that tenant plans are cached in Redis and reused.""" + # Arrange + account, tenant = self._create_account_and_tenant(plan="sandbox") + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create messages in two batches (to test cache reuse) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + batch1_msgs = [] + for i in range(2): + msg = self._create_message( + app, conv, created_at=expired_date + datetime.timedelta(hours=i), with_relations=False + ) + batch1_msgs.append(msg.id) + + batch2_msgs = [] + for i in range(2): + msg = self._create_message( + app, conv, created_at=expired_date + datetime.timedelta(hours=10 + i), with_relations=False + ) + batch2_msgs.append(msg.id) + + # Mock billing service with new structure + mock_get_plan_bulk = MagicMock( + return_value={ + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + ) + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk", mock_get_plan_bulk): + # Act - First call + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats1 = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=2, # Process 2 messages per batch + ) + + # Check billing service was called (cache miss) + assert mock_get_plan_bulk.call_count == 1 + first_call_count = mock_get_plan_bulk.call_count + + # Verify Redis cache was populated + cache_key = f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}{tenant.id}" + cached_plan = redis_client.get(cache_key) + assert cached_plan is not None + cached_plan_data = 
json.loads(cached_plan.decode("utf-8")) + assert cached_plan_data["plan"] == CloudPlan.SANDBOX + assert cached_plan_data["expiration_date"] == -1 + + # Act - Second call with same tenant (should use cache) + # Create more messages for the same tenant + batch3_msgs = [] + for i in range(2): + msg = self._create_message( + app, conv, created_at=expired_date + datetime.timedelta(hours=20 + i), with_relations=False + ) + batch3_msgs.append(msg.id) + + stats2 = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=2, + ) + + # Assert - Billing service should not be called again (cache hit) + # The call count should be the same + assert mock_get_plan_bulk.call_count == first_call_count # Same tenant, should use cache + + # Verify all messages were deleted + total_expected = len(batch1_msgs) + len(batch2_msgs) + len(batch3_msgs) + assert stats1["total_deleted"] + stats2["total_deleted"] == total_expected + + def test_time_range_filtering(self, db_session_with_containers): + """Test that messages are correctly filtered by [start_from, end_before) time range.""" + # Arrange + account, tenant = self._create_account_and_tenant(plan="sandbox") + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + base_date = datetime.datetime(2024, 1, 15, 12, 0, 0) + + # Create messages: before range, in range, after range + msg_before = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from + with_relations=False, + ) + msg_before_id = msg_before.id + + msg_at_start = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) + with_relations=False, + ) + msg_at_start_id = msg_at_start.id + + msg_in_range = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range + with_relations=False, + ) + msg_in_range_id = msg_in_range.id + + msg_at_end = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) + with_relations=False, + ) + msg_at_end_id = msg_at_end.id + + msg_after = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before + with_relations=False, + ) + msg_after_id = msg_after.id + + db.session.commit() # Commit all messages + + # Mock billing service with new structure + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act - Clean with specific time range [2024-01-10, 2024-01-20) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + start_from=datetime.datetime(2024, 1, 10, 12, 0, 0), + end_before=datetime.datetime(2024, 1, 20, 12, 0, 0), + graceful_period=21, # Use default graceful period + batch_size=100, + ) + + # Assert - Only messages in [start_from, end_before) should be deleted + assert stats["total_messages"] == 2 # msg_at_start and msg_in_range + assert stats["total_deleted"] == 2 + + # Verify specific messages using stored IDs + # Before range, kept + assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 + # At start (inclusive), deleted + assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 + # In range, deleted + assert 
db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 + # At end (exclusive), kept + assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 + # After range, kept + assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 + + def test_clean_with_graceful_period_scenarios(self, db_session_with_containers): + """Test cleaning with different graceful period scenarios.""" + # Arrange - Create 5 different tenants with different plan and expiration scenarios + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + graceful_period = 8 # Use 8 days for this test to validate boundary conditions + + # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) + # Should NOT be deleted + account1, tenant1 = self._create_account_and_tenant(plan="sandbox") + app1 = self._create_app(tenant1, account1) + conv1 = self._create_conversation(app1) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1_id = msg1.id # Save ID before potential deletion + expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period + + # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) + # Should be deleted + account2, tenant2 = self._create_account_and_tenant(plan="sandbox") + app2 = self._create_app(tenant2, account2) + conv2 = self._create_conversation(app2) + msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + msg2_id = msg2.id # Save ID before potential deletion + expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period + + # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) + # Should be deleted + account3, tenant3 = self._create_account_and_tenant(plan="sandbox") + app3 = self._create_app(tenant3, account3) + conv3 = self._create_conversation(app3) + msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False) + msg3_id = msg3.id # Save ID before potential deletion + + # Scenario 4: Non-sandbox plan (professional) with no expiration (future date) + # Should NOT be deleted + account4, tenant4 = self._create_account_and_tenant(plan="professional") + app4 = self._create_app(tenant4, account4) + conv4 = self._create_conversation(app4) + msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) + msg4_id = msg4.id # Save ID before potential deletion + future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year + + # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) + # Should NOT be deleted (boundary is exclusive: > graceful_period) + account5, tenant5 = self._create_account_and_tenant(plan="sandbox") + app5 = self._create_app(tenant5, account5) + conv5 = self._create_conversation(app5) + msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) + msg5_id = msg5.id # Save ID before potential deletion + expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary + + db.session.commit() + + # Mock billing service with all scenarios + plan_map = { + tenant1.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_5_days_ago, + }, + tenant2.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_10_days_ago, + }, + tenant3.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenant4.id: { + 
"plan": CloudPlan.PROFESSIONAL, + "expiration_date": future_expiration, + }, + tenant5.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_exactly_8_days_ago, + }, + } + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + + # Mock datetime.now() to use the same timestamp as test setup + # This ensures deterministic behavior for boundary conditions (scenario 5) + with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime: + mock_datetime.datetime.now.return_value = datetime.datetime.fromtimestamp( + now_timestamp, tz=datetime.UTC + ) + mock_datetime.timedelta = datetime.timedelta # Keep original timedelta + + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=graceful_period, + batch_size=100, + ) + + # Assert - Only messages from scenario 2 and 3 should be deleted + assert stats["total_messages"] == 2 + assert stats["total_deleted"] == 2 + + # Verify each scenario using saved IDs + assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept + assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted + assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted + assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept + assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept + + def test_clean_with_tenant_whitelist(self, db_session_with_containers, mock_whitelist): + """Test that whitelisted tenants' messages are not deleted even if they are sandbox and expired.""" + # Arrange - Create 3 sandbox tenants with expired messages + tenants_data = [] + for i in range(3): + account, tenant = self._create_account_and_tenant(plan="sandbox") + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) + + tenants_data.append( + { + "tenant": tenant, + "message_id": msg.id, + } + ) + + # Mock billing service - all tenants are sandbox with no subscription + plan_map = { + tenants_data[0]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + }, + tenants_data[1]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + }, + tenants_data[2]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + }, + } + + # Setup whitelist - tenant0 and tenant1 are whitelisted, tenant2 is not + whitelist = [tenants_data[0]["tenant"].id, tenants_data[1]["tenant"].id] + mock_whitelist.return_value = whitelist + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=100, + ) + + # Assert - Only tenant2's message should be deleted (not whitelisted) + assert stats["total_messages"] == 1 + assert stats["total_deleted"] == 1 + + # 
Verify tenant0's message still exists (whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 + + # Verify tenant1's message still exists (whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + + # Verify tenant2's message was deleted (not whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 + + def test_clean_with_whitelist_and_grace_period(self, db_session_with_containers, mock_whitelist): + """Test that whitelist takes precedence over grace period logic.""" + # Arrange - Create 2 sandbox tenants + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + # Tenant1: whitelisted, expired beyond grace period + account1, tenant1 = self._create_account_and_tenant(plan="sandbox") + app1 = self._create_app(tenant1, account1) + conv1 = self._create_conversation(app1) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace + + # Tenant2: not whitelisted, within grace period + account2, tenant2 = self._create_account_and_tenant(plan="sandbox") + app2 = self._create_app(tenant2, account2) + conv2 = self._create_conversation(app2) + msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace + + # Mock billing service + plan_map = { + tenant1.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_30_days_ago, # Beyond grace period + }, + tenant2.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_10_days_ago, # Within grace period + }, + } + + # Setup whitelist - only tenant1 is whitelisted + whitelist = [tenant1.id] + mock_whitelist.return_value = whitelist + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=100, + ) + + # Assert - No messages should be deleted + # tenant1: whitelisted (would be deleted based on grace period, but protected by whitelist) + # tenant2: within grace period (not eligible for deletion) + assert stats["total_messages"] == 0 + assert stats["total_deleted"] == 0 + + # Verify both messages still exist + assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted + assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period + + def test_clean_with_empty_whitelist(self, db_session_with_containers, mock_whitelist): + """Test that empty whitelist behaves as no whitelist (all eligible messages are deleted).""" + # Arrange - Create sandbox tenant with expired messages + account, tenant = self._create_account_and_tenant(plan="sandbox") + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg_ids = [] + for i in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg_ids.append(msg.id) + + # Mock billing service + plan_map = { 
+ tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Setup empty whitelist (default behavior from fixture) + mock_whitelist.return_value = [] + + with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range( + end_before=end_before, + graceful_period=21, # Use default graceful period + batch_size=100, + ) + + # Assert - All messages should be deleted (no whitelist protection) + assert stats["total_messages"] == 3 + assert stats["total_deleted"] == 3 + + # Verify all messages were deleted + assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 diff --git a/api/tests/unit_tests/services/test_sandbox_messages_clean_service.py b/api/tests/unit_tests/services/test_sandbox_messages_clean_service.py new file mode 100644 index 0000000000..3ecc4be22d --- /dev/null +++ b/api/tests/unit_tests/services/test_sandbox_messages_clean_service.py @@ -0,0 +1,588 @@ +""" +Unit tests for SandboxMessagesCleanService. + +This module tests parameter validation, method invocation, and error handling +without database dependencies (using mocks). +""" + +import datetime +from unittest.mock import patch + +import pytest + +from enums.cloud_plan import CloudPlan +from services.sandbox_messages_clean_service import SandboxMessagesCleanService + + +class MockMessage: + """Mock message object for testing.""" + + def __init__(self, id: str, app_id: str, created_at: datetime.datetime | None = None): + self.id = id + self.app_id = app_id + self.created_at = created_at or datetime.datetime.now() + + +class TestFilterExpiredSandboxMessages: + """Unit tests for _filter_expired_sandbox_messages method.""" + + def test_filter_missing_tenant_mapping(self): + """Test that messages with missing app-to-tenant mapping are excluded.""" + # Arrange + messages = [ + MockMessage("msg1", "app1"), + MockMessage("msg2", "app2"), + ] + app_to_tenant = {} # No mapping + tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}} + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=[], + graceful_period_days=8, + current_timestamp=1000000, + ) + + # Assert + assert result == [] + + def test_filter_missing_tenant_plan(self): + """Test that messages with missing tenant plan are excluded.""" + # Arrange + messages = [ + MockMessage("msg1", "app1"), + MockMessage("msg2", "app2"), + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + } + tenant_plans = {} # No plans + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=[], + graceful_period_days=8, + current_timestamp=1000000, + ) + + # Assert + assert result == [] + + def test_filter_no_previous_subscription(self): + """Test that messages with no previous subscription (expiration_date=-1) are deleted.""" + # Arrange + messages = [ + MockMessage("msg1", "app1"), + MockMessage("msg2", "app2"), + MockMessage("msg3", "app3"), + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + 
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=[], + graceful_period_days=8, + current_timestamp=1000000, + ) + + # Assert - all messages should be deleted + assert set(result) == {"msg1", "msg2", "msg3"} + + def test_filter_all_within_grace_period(self): + """Test that no messages are deleted when all are within grace period.""" + # Arrange + now = 1000000 + # All expired recently (within 8 day grace period) + expired_1_day_ago = now - (1 * 24 * 60 * 60) + expired_3_days_ago = now - (3 * 24 * 60 * 60) + expired_7_days_ago = now - (7 * 24 * 60 * 60) + + messages = [ + MockMessage("msg1", "app1"), + MockMessage("msg2", "app2"), + MockMessage("msg3", "app3"), + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_3_days_ago}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago}, + } + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=[], + graceful_period_days=8, + current_timestamp=now, + ) + + # Assert - no messages should be deleted + assert result == [] + + def test_filter_partial_expired_beyond_grace_period(self): + """Test filtering when some messages expired beyond grace period.""" + # Arrange + now = 1000000 + graceful_period = 8 + + # Different expiration scenarios + expired_5_days_ago = now - (5 * 24 * 60 * 60) # Within grace - keep + expired_10_days_ago = now - (10 * 24 * 60 * 60) # Beyond grace - delete + expired_30_days_ago = now - (30 * 24 * 60 * 60) # Beyond grace - delete + expired_exactly_8_days_ago = now - (8 * 24 * 60 * 60) # Exactly at boundary - keep + expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond - delete + + messages = [ + MockMessage("msg1", "app1"), # Within grace + MockMessage("msg2", "app2"), # Beyond grace + MockMessage("msg3", "app3"), # Beyond grace + MockMessage("msg4", "app4"), # No subscription - delete + MockMessage("msg5", "app5"), # Exactly at boundary + MockMessage("msg6", "app6"), # Just beyond grace + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + "app4": "tenant4", + "app5": "tenant5", + "app6": "tenant6", + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_10_days_ago}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago}, + "tenant4": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant5": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago}, + "tenant6": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago}, + } + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=[], + graceful_period_days=graceful_period, + current_timestamp=now, + ) + + # Assert - msg2, msg3, msg4, msg6 should be deleted + # msg1 and msg5 are within/at grace period boundary + assert set(result) 
== {"msg2", "msg3", "msg4", "msg6"} + + def test_filter_complex_mixed_scenario(self): + """Test complex scenario with mixed plans, expirations, and missing mappings.""" + # Arrange + now = 1000000 + sandbox_expired_old = now - (15 * 24 * 60 * 60) # 15 days ago - beyond grace + sandbox_expired_recent = now - (3 * 24 * 60 * 60) # 3 days ago - within grace + future_expiration = now + (30 * 24 * 60 * 60) # 30 days in future - active paid plan + + messages = [ + MockMessage("msg1", "app1"), # Sandbox, no subscription - delete + MockMessage("msg2", "app2"), # Sandbox, expired old - delete + MockMessage("msg3", "app3"), # Sandbox, within grace - keep + MockMessage("msg4", "app4"), # Team plan, active - keep + MockMessage("msg5", "app5"), # No tenant mapping - keep + MockMessage("msg6", "app6"), # No plan info - keep + MockMessage("msg7", "app7"), # Sandbox, expired old - delete + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + "app4": "tenant4", + "app6": "tenant6", # Has mapping but no plan + "app7": "tenant7", + # app5 has no mapping + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent}, + "tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration}, + "tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + # tenant6 has no plan + } + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=[], + graceful_period_days=8, + current_timestamp=now, + ) + + # Assert - only sandbox expired beyond grace period and no subscription + assert set(result) == {"msg1", "msg2", "msg7"} + + def test_filter_empty_inputs(self): + """Test filtering with empty inputs returns empty list.""" + # Arrange - empty messages + result1 = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=[], + app_to_tenant={"app1": "tenant1"}, + tenant_plans={"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}, + tenant_whitelist=[], + graceful_period_days=8, + current_timestamp=1000000, + ) + + # Assert + assert result1 == [] + + def test_filter_uses_default_timestamp(self): + """Test that method uses current time when timestamp not provided.""" + # Arrange + messages = [MockMessage("msg1", "app1")] + app_to_tenant = {"app1": "tenant1"} + tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}} + + # Act - don't provide current_timestamp + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=[], + graceful_period_days=8, + # current_timestamp not provided - should use datetime.now() + ) + + # Assert - should still work and return msg1 (no subscription) + assert result == ["msg1"] + + def test_filter_with_whitelist(self): + """Test that messages from whitelisted tenants are excluded from deletion.""" + # Arrange + messages = [ + MockMessage("msg1", "app1"), # Whitelisted tenant - should be kept + MockMessage("msg2", "app2"), # Not whitelisted - should be deleted + MockMessage("msg3", "app3"), # Whitelisted tenant - should be kept + MockMessage("msg4", "app4"), # Not whitelisted - should be deleted + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": 
"tenant3", + "app4": "tenant4", + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant4": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + tenant_whitelist = ["tenant1", "tenant3"] # Whitelist tenant1 and tenant3 + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=tenant_whitelist, + graceful_period_days=8, + current_timestamp=1000000, + ) + + # Assert - only msg2 and msg4 should be deleted (not whitelisted) + assert set(result) == {"msg2", "msg4"} + + def test_filter_with_whitelist_and_grace_period(self): + """Test whitelist takes precedence over grace period logic.""" + # Arrange + now = 1000000 + expired_long_ago = now - (30 * 24 * 60 * 60) # Expired 30 days ago + + messages = [ + MockMessage("msg1", "app1"), # Whitelisted, expired long ago - should be kept + MockMessage("msg2", "app2"), # Not whitelisted, expired long ago - should be deleted + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_long_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_long_ago}, + } + tenant_whitelist = ["tenant1"] # Only tenant1 is whitelisted + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=tenant_whitelist, + graceful_period_days=8, + current_timestamp=now, + ) + + # Assert - only msg2 should be deleted + assert result == ["msg2"] + + def test_filter_whitelist_with_non_sandbox_plans(self): + """Test that whitelist only affects sandbox plan messages.""" + # Arrange + now = 1000000 + future_expiration = now + (30 * 24 * 60 * 60) + + messages = [ + MockMessage("msg1", "app1"), # Sandbox, whitelisted - kept + MockMessage("msg2", "app2"), # Team plan, whitelisted - kept (not sandbox) + MockMessage("msg3", "app3"), # Sandbox, not whitelisted - deleted + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + tenant_whitelist = ["tenant1", "tenant2"] + + # Act + result = SandboxMessagesCleanService._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + tenant_whitelist=tenant_whitelist, + graceful_period_days=8, + current_timestamp=now, + ) + + # Assert - only msg3 should be deleted (sandbox, not whitelisted) + assert result == ["msg3"] + + +class TestCleanSandboxMessagesByTimeRange: + """Unit tests for clean_sandbox_messages_by_time_range method.""" + + @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range") + def test_valid_time_range_and_args(self, mock_clean): + """Test with valid time range and other parameters.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 12, 31, 23, 59, 59) + batch_size = 500 + dry_run = True + + mock_clean.return_value = { + "batches": 5, + "total_messages": 100, + "total_deleted": 100, + } + + # Act + 
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
+            start_from=start_from,
+            end_before=end_before,
+            batch_size=batch_size,
+            dry_run=dry_run,
+        )
+
+        # Assert - no exception raised; delegate called with the expected arguments
+        mock_clean.assert_called_once_with(
+            start_from=start_from,
+            end_before=end_before,
+            graceful_period=21,
+            batch_size=batch_size,
+            dry_run=dry_run,
+        )
+
+    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
+    def test_with_default_args(self, mock_clean):
+        """Test with default args."""
+        # Arrange
+        start_from = datetime.datetime(2024, 1, 1)
+        end_before = datetime.datetime(2024, 2, 1)
+
+        mock_clean.return_value = {
+            "batches": 2,
+            "total_messages": 50,
+            "total_deleted": 0,
+        }
+
+        # Act
+        SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
+            start_from=start_from,
+            end_before=end_before,
+        )
+
+        # Assert
+        mock_clean.assert_called_once_with(
+            start_from=start_from,
+            end_before=end_before,
+            graceful_period=21,
+            batch_size=1000,
+            dry_run=False,
+        )
+
+    def test_invalid_time_range(self):
+        """Test invalid time range raises ValueError."""
+        # Arrange
+        same_time = datetime.datetime(2024, 1, 1, 12, 0, 0)
+
+        # Act & Assert start equals end
+        with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
+            SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
+                start_from=same_time,
+                end_before=same_time,
+            )
+
+        # Arrange
+        start_from = datetime.datetime(2024, 12, 31)
+        end_before = datetime.datetime(2024, 1, 1)
+
+        # Act & Assert start after end
+        with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
+            SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
+                start_from=start_from,
+                end_before=end_before,
+            )
+
+    def test_invalid_batch_size(self):
+        """Test invalid batch_size raises ValueError."""
+        # Arrange
+        start_from = datetime.datetime(2024, 1, 1)
+        end_before = datetime.datetime(2024, 2, 1)
+
+        # Act & Assert batch_size = 0
+        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
+            SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
+                start_from=start_from,
+                end_before=end_before,
+                batch_size=0,
+            )
+
+        # Act & Assert batch_size < 0
+        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
+            SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
+                start_from=start_from,
+                end_before=end_before,
+                batch_size=-100,
+            )
+
+
+class TestCleanSandboxMessagesByDays:
+    """Unit tests for clean_sandbox_messages_by_days method."""
+
+    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
+    def test_default_days(self, mock_clean):
+        """Test with default 30 days."""
+        # Arrange
+        mock_clean.return_value = {"batches": 3, "total_messages": 75, "total_deleted": 75}
+
+        # Act
+        with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
+            fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
+            mock_datetime.datetime.now.return_value = fixed_now
+            mock_datetime.timedelta = datetime.timedelta  # Keep original timedelta
+
+            SandboxMessagesCleanService.clean_sandbox_messages_by_days()
+
+        # Assert
+        expected_end_before = fixed_now - datetime.timedelta(days=30)  # default days=30
+        mock_clean.assert_called_once_with(
+            end_before=expected_end_before,
+            start_from=None,
+            graceful_period=21,
+            batch_size=1000,
+            dry_run=False,
+        )
+
+    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
+    def test_custom_days(self, mock_clean):
+        """Test with custom number of days."""
+        # Arrange
+        custom_days = 90
+        mock_clean.return_value = {"batches": 10, "total_messages": 500, "total_deleted": 500}
+
+        # Act
+        with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
+            fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
+            mock_datetime.datetime.now.return_value = fixed_now
+            mock_datetime.timedelta = datetime.timedelta  # Keep original timedelta
+
+            SandboxMessagesCleanService.clean_sandbox_messages_by_days(days=custom_days)
+
+        # Assert
+        expected_end_before = fixed_now - datetime.timedelta(days=custom_days)
+        mock_clean.assert_called_once_with(
+            end_before=expected_end_before,
+            start_from=None,
+            graceful_period=21,
+            batch_size=1000,
+            dry_run=False,
+        )
+
+    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
+    def test_zero_days(self, mock_clean):
+        """Test with days=0 (clean all messages before now)."""
+        # Arrange
+        mock_clean.return_value = {"batches": 0, "total_messages": 0, "total_deleted": 0}
+
+        # Act
+        with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
+            fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
+            mock_datetime.datetime.now.return_value = fixed_now
+            mock_datetime.timedelta = datetime.timedelta  # Keep original timedelta
+
+            SandboxMessagesCleanService.clean_sandbox_messages_by_days(days=0)
+
+        # Assert
+        expected_end_before = fixed_now - datetime.timedelta(days=0)  # same as fixed_now
+        mock_clean.assert_called_once_with(
+            end_before=expected_end_before,
+            start_from=None,
+            graceful_period=21,
+            batch_size=1000,
+            dry_run=False,
+        )
+
+    def test_invalid_batch_size(self):
+        """Test invalid batch_size raises ValueError."""
+        # Act & Assert batch_size = 0
+        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
+            SandboxMessagesCleanService.clean_sandbox_messages_by_days(
+                days=30,
+                batch_size=0,
+            )
+
+        # Act & Assert batch_size < 0
+        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
+            SandboxMessagesCleanService.clean_sandbox_messages_by_days(
+                days=30,
+                batch_size=-500,
+            )
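
For readers of the unit tests above, here is a minimal sketch of the filtering rule that TestFilterExpiredSandboxMessages appears to assume for SandboxMessagesCleanService._filter_expired_sandbox_messages. It is reconstructed purely from the assertions (sandbox plan only, whitelist always wins, expiration_date == -1 means the tenant never had a paid subscription, deletion only once the expiration is strictly older than the grace period); the function name filter_expired_sandbox_message_ids and the plan-dict shape are illustrative assumptions, not the service's actual implementation.

import datetime

from enums.cloud_plan import CloudPlan

SECONDS_PER_DAY = 24 * 60 * 60


def filter_expired_sandbox_message_ids(
    messages,
    app_to_tenant,
    tenant_plans,
    tenant_whitelist,
    graceful_period_days,
    current_timestamp=None,
):
    """Return ids of messages owned by expired, non-whitelisted sandbox tenants."""
    if current_timestamp is None:
        current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
    whitelist = set(tenant_whitelist or [])
    grace_seconds = graceful_period_days * SECONDS_PER_DAY

    deletable_ids = []
    for message in messages:
        tenant_id = app_to_tenant.get(message.app_id)
        if tenant_id is None:
            continue  # no app-to-tenant mapping: keep the message
        plan_info = tenant_plans.get(tenant_id)
        if plan_info is None:
            continue  # no billing info for the tenant: keep the message
        if tenant_id in whitelist:
            continue  # whitelisted tenants are always protected
        if plan_info["plan"] != CloudPlan.SANDBOX:
            continue  # only sandbox tenants are eligible for cleanup

        expiration = plan_info["expiration_date"]
        if expiration == -1 or current_timestamp - expiration > grace_seconds:
            # never subscribed, or the paid plan expired longer ago than the grace period
            deletable_ids.append(message.id)
    return deletable_ids

Note the strict > comparison: a message whose subscription expired exactly graceful_period_days ago is kept, which is what test_filter_partial_expired_beyond_grace_period expects for the 8-day boundary case.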
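
Along the same lines, the two wrapper test classes only pin down argument validation, defaults, and delegation. The following sketch captures that contract; it is inferred from the tests rather than taken from the real SandboxMessagesCleanService, and the module-level functions plus the placeholder _clean_sandbox_messages_by_time_range are hypothetical.

import datetime

DEFAULT_GRACEFUL_PERIOD_DAYS = 21
DEFAULT_BATCH_SIZE = 1000


def _clean_sandbox_messages_by_time_range(**kwargs):
    # Placeholder for the private worker that the tests mock out.
    return {"batches": 0, "total_messages": 0, "total_deleted": 0}


def clean_sandbox_messages_by_time_range(
    start_from=None,
    end_before=None,
    graceful_period=DEFAULT_GRACEFUL_PERIOD_DAYS,
    batch_size=DEFAULT_BATCH_SIZE,
    dry_run=False,
):
    # Validate arguments before delegating; both checks mirror the ValueError
    # messages the tests match with pytest.raises(..., match=...).
    if batch_size <= 0:
        raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
    if start_from is not None and end_before is not None and start_from >= end_before:
        raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
    return _clean_sandbox_messages_by_time_range(
        start_from=start_from,
        end_before=end_before,
        graceful_period=graceful_period,
        batch_size=batch_size,
        dry_run=dry_run,
    )


def clean_sandbox_messages_by_days(
    days=30,
    graceful_period=DEFAULT_GRACEFUL_PERIOD_DAYS,
    batch_size=DEFAULT_BATCH_SIZE,
    dry_run=False,
):
    # Same batch_size validation, then convert the day count into an exclusive
    # upper bound and delegate with start_from=None, as the by-days tests assert.
    if batch_size <= 0:
        raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
    end_before = datetime.datetime.now() - datetime.timedelta(days=days)
    return _clean_sandbox_messages_by_time_range(
        start_from=None,
        end_before=end_before,
        graceful_period=graceful_period,
        batch_size=batch_size,
        dry_run=dry_run,
    )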