diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 6c5d6f4135..b96db5a390 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -65,6 +65,9 @@ jobs: defaults: run: working-directory: ./web + permissions: + checks: write + pull-requests: read steps: - name: Checkout code @@ -103,7 +106,15 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web run: | - pnpm run lint + pnpm run lint:report + continue-on-error: true + + - name: Annotate Code + if: steps.changed-files.outputs.any_changed == 'true' + uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae + with: + eslint-report: web/eslint_report.json + github-token: ${{ secrets.GITHUB_TOKEN }} - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/api/commands.py b/api/commands.py index 20ce22a6c7..e223df74d4 100644 --- a/api/commands.py +++ b/api/commands.py @@ -3,6 +3,7 @@ import datetime import json import logging import secrets +import time from typing import Any import click @@ -46,6 +47,8 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration from services.plugin.plugin_service import PluginService +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup from tasks.remove_app_and_related_data_task import delete_draft_variables_batch @@ -2172,3 +2175,79 @@ def migrate_oss( except Exception as e: db.session.rollback() click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red")) + + +@click.command("clean-expired-messages", help="Clean expired messages.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Lower bound (inclusive) for created_at.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Upper bound (exclusive) for created_at.", +) +@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.") +@click.option( + "--graceful-period", + default=21, + show_default=True, + help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.", +) +@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting") +def clean_expired_messages( + batch_size: int, + graceful_period: int, + start_from: datetime.datetime, + end_before: datetime.datetime, + dry_run: bool, +): + """ + Clean expired messages and related data for tenants based on clean policy. + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + + start_at = time.perf_counter() + + try: + # Create policy based on billing configuration + # NOTE: graceful_period will be ignored when billing is disabled. + policy = create_message_clean_policy(graceful_period_days=graceful_period) + + # Create and run the cleanup service + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise + + click.echo(click.style("messages cleanup completed.", fg="green")) diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index c79764983b..d37217e168 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -6,6 +6,7 @@ from .create_site_record_when_app_created import handle as handle_create_site_re from .delete_tool_parameters_cache_when_sync_draft_workflow import ( handle as handle_delete_tool_parameters_cache_when_sync_draft_workflow, ) +from .queue_credential_sync_when_tenant_created import handle as handle_queue_credential_sync_when_tenant_created from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published @@ -30,6 +31,7 @@ __all__ = [ "handle_create_installed_app_when_app_created", "handle_create_site_record_when_app_created", "handle_delete_tool_parameters_cache_when_sync_draft_workflow", + "handle_queue_credential_sync_when_tenant_created", "handle_sync_plugin_trigger_when_app_created", "handle_sync_webhook_when_app_created", "handle_sync_workflow_schedule_when_app_published", diff --git a/api/events/event_handlers/queue_credential_sync_when_tenant_created.py b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py new file mode 100644 index 0000000000..6566c214b0 --- /dev/null +++ b/api/events/event_handlers/queue_credential_sync_when_tenant_created.py @@ -0,0 +1,19 @@ +from configs import dify_config +from events.tenant_event import tenant_was_created +from services.enterprise.workspace_sync import WorkspaceSyncService + + +@tenant_was_created.connect +def handle(sender, **kwargs): + """Queue credential sync when a tenant/workspace is created.""" + # Only queue sync tasks if plugin manager (enterprise feature) is enabled + if not dify_config.ENTERPRISE_ENABLED: + return + + tenant = sender + + # Determine source from kwargs if available, otherwise use generic + source = kwargs.get("source", "tenant_created") + + # Queue credential sync task to Redis for enterprise backend to process + WorkspaceSyncService.queue_credential_sync(tenant.id, source=source) diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index c32130d377..51e2c6cdd5 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + clean_expired_messages, clean_workflow_runs, cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, @@ -58,6 +59,7 @@ def init_app(app: DifyApp): transform_datasource_credentials, install_rag_pipeline_plugins, clean_workflow_runs, + clean_expired_messages, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py new file mode 100644 index 0000000000..758369ba99 --- /dev/null +++ b/api/migrations/versions/2026_01_12_1729-3334862ee907_feat_add_created_at_id_index_to_messages.py @@ -0,0 +1,33 @@ +"""feat: add created_at id index to messages + +Revision ID: 3334862ee907 +Revises: 905527cc8fd3 +Create Date: 2026-01-12 17:29:44.846544 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '3334862ee907' +down_revision = '905527cc8fd3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_created_at_id_idx') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 68903e86eb..d6a0aa3bb3 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -968,6 +968,7 @@ class Message(Base): Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), Index("message_created_at_idx", "created_at"), Index("message_app_mode_idx", "app_mode"), + Index("message_created_at_id_idx", "created_at", "id"), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 352a84b592..e85bba8823 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -1,90 +1,62 @@ -import datetime import logging import time import click -from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config -from enums.cloud_plan import CloudPlan -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from models.model import ( - App, - Message, - MessageAgentThought, - MessageAnnotation, - MessageChain, - MessageFeedback, - MessageFile, -) -from models.web import SavedMessage -from services.feature_service import FeatureService +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService logger = logging.getLogger(__name__) -@app.celery.task(queue="dataset") +@app.celery.task(queue="retention") def clean_messages(): - click.echo(click.style("Start clean messages.", fg="green")) - start_at = time.perf_counter() - plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( - days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING - ) - while True: - try: - # Main query with join and filter - messages = ( - db.session.query(Message) - .where(Message.created_at < plan_sandbox_clean_message_day) - .order_by(Message.created_at.desc()) - .limit(100) - .all() - ) + """ + Clean expired messages based on clean policy. - except SQLAlchemyError: - raise - if not messages: - break - for message in messages: - app = db.session.query(App).filter_by(id=message.app_id).first() - if not app: - logger.warning( - "Expected App record to exist, but none was found, app_id=%s, message_id=%s", - message.app_id, - message.id, - ) - continue - features_cache_key = f"features:{app.tenant_id}" - plan_cache = redis_client.get(features_cache_key) - if plan_cache is None: - features = FeatureService.get_features(app.tenant_id) - redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) - plan = features.billing.subscription.plan - else: - plan = plan_cache.decode() - if plan == CloudPlan.SANDBOX: - # clean related message - db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete( - synchronize_session=False - ) - db.session.query(Message).where(Message.id == message.id).delete() - db.session.commit() - end_at = time.perf_counter() - click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green")) + This task uses MessagesCleanService to efficiently clean messages in batches. + The behavior depends on BILLING_ENABLED configuration: + - BILLING_ENABLED=True: only delete messages from sandbox tenants (with whitelist/grace period) + - BILLING_ENABLED=False: delete all messages within the time range + """ + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + start_at = time.perf_counter() + + try: + # Create policy based on billing configuration + policy = create_message_clean_policy( + graceful_period_days=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD, + ) + + # Create and run the cleanup service + service = MessagesCleanService.from_days( + policy=policy, + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise diff --git a/api/services/enterprise/workspace_sync.py b/api/services/enterprise/workspace_sync.py new file mode 100644 index 0000000000..acfe325397 --- /dev/null +++ b/api/services/enterprise/workspace_sync.py @@ -0,0 +1,58 @@ +import json +import logging +import uuid +from datetime import UTC, datetime + +from redis import RedisError + +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + +WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue" +WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing" + + +class WorkspaceSyncService: + """Service to publish workspace sync tasks to Redis queue for enterprise backend consumption""" + + @staticmethod + def queue_credential_sync(workspace_id: str, *, source: str) -> bool: + """ + Queue a credential sync task for a newly created workspace. + + This publishes a task to Redis that will be consumed by the enterprise backend + worker to sync credentials with the plugin-manager. + + Args: + workspace_id: The workspace/tenant ID to sync credentials for + source: Source of the sync request (for debugging/tracking) + + Returns: + bool: True if task was queued successfully, False otherwise + """ + try: + task = { + "task_id": str(uuid.uuid4()), + "workspace_id": workspace_id, + "retry_count": 0, + "created_at": datetime.now(UTC).isoformat(), + "source": source, + } + + # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP + redis_client.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task)) + + logger.info( + "Queued credential sync task for workspace %s, task_id: %s, source: %s", + workspace_id, + task["task_id"], + source, + ) + return True + + except (RedisError, TypeError) as e: + logger.error("Failed to queue credential sync for workspace %s: %s", workspace_id, str(e), exc_info=True) + # Don't raise - we don't want to fail workspace creation if queueing fails + # The scheduled task will catch it later + return False diff --git a/api/services/retention/conversation/messages_clean_policy.py b/api/services/retention/conversation/messages_clean_policy.py new file mode 100644 index 0000000000..6e647b983b --- /dev/null +++ b/api/services/retention/conversation/messages_clean_policy.py @@ -0,0 +1,216 @@ +import datetime +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from services.billing_service import BillingService, SubscriptionPlan + +logger = logging.getLogger(__name__) + + +@dataclass +class SimpleMessage: + id: str + app_id: str + created_at: datetime.datetime + + +class MessagesCleanPolicy(ABC): + """ + Abstract base class for message cleanup policies. + + A policy determines which messages from a batch should be deleted. + """ + + @abstractmethod + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + """ + Filter messages and return IDs of messages that should be deleted. + + Args: + messages: Batch of messages to evaluate + app_to_tenant: Mapping from app_id to tenant_id + + Returns: + List of message IDs that should be deleted + """ + ... + + +class BillingDisabledPolicy(MessagesCleanPolicy): + """ + Policy for community or enterpriseedition (billing disabled). + + No special filter logic, just return all message ids. + """ + + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + return [msg.id for msg in messages] + + +class BillingSandboxPolicy(MessagesCleanPolicy): + """ + Policy for sandbox plan tenants in cloud edition (billing enabled). + + Filters messages based on sandbox plan expiration rules: + - Skip tenants in the whitelist + - Only delete messages from sandbox plan tenants + - Respect grace period after subscription expiration + - Safe default: if tenant mapping or plan is missing, do NOT delete + """ + + def __init__( + self, + plan_provider: Callable[[Sequence[str]], dict[str, SubscriptionPlan]], + graceful_period_days: int = 21, + tenant_whitelist: Sequence[str] | None = None, + current_timestamp: int | None = None, + ) -> None: + self._graceful_period_days = graceful_period_days + self._tenant_whitelist: Sequence[str] = tenant_whitelist or [] + self._plan_provider = plan_provider + self._current_timestamp = current_timestamp + + def filter_message_ids( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + ) -> Sequence[str]: + """ + Filter messages based on sandbox plan expiration rules. + + Args: + messages: Batch of messages to evaluate + app_to_tenant: Mapping from app_id to tenant_id + + Returns: + List of message IDs that should be deleted + """ + if not messages or not app_to_tenant: + return [] + + # Get unique tenant_ids and fetch subscription plans + tenant_ids = list(set(app_to_tenant.values())) + tenant_plans = self._plan_provider(tenant_ids) + + if not tenant_plans: + return [] + + # Apply sandbox deletion rules + return self._filter_expired_sandbox_messages( + messages=messages, + app_to_tenant=app_to_tenant, + tenant_plans=tenant_plans, + ) + + def _filter_expired_sandbox_messages( + self, + messages: Sequence[SimpleMessage], + app_to_tenant: dict[str, str], + tenant_plans: dict[str, SubscriptionPlan], + ) -> list[str]: + """ + Filter messages that should be deleted based on sandbox plan expiration. + + A message should be deleted if: + 1. It belongs to a sandbox tenant AND + 2. Either: + a) The tenant has no previous subscription (expiration_date == -1), OR + b) The subscription expired more than graceful_period_days ago + + Args: + messages: List of message objects with id and app_id attributes + app_to_tenant: Mapping from app_id to tenant_id + tenant_plans: Mapping from tenant_id to subscription plan info + + Returns: + List of message IDs that should be deleted + """ + current_timestamp = self._current_timestamp + if current_timestamp is None: + current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + sandbox_message_ids: list[str] = [] + graceful_period_seconds = self._graceful_period_days * 24 * 60 * 60 + + for msg in messages: + # Get tenant_id for this message's app + tenant_id = app_to_tenant.get(msg.app_id) + if not tenant_id: + continue + + # Skip tenant messages in whitelist + if tenant_id in self._tenant_whitelist: + continue + + # Get subscription plan for this tenant + tenant_plan = tenant_plans.get(tenant_id) + if not tenant_plan: + continue + + plan = str(tenant_plan["plan"]) + expiration_date = int(tenant_plan["expiration_date"]) + + # Only process sandbox plans + if plan != CloudPlan.SANDBOX: + continue + + # Case 1: No previous subscription (-1 means never had a paid subscription) + if expiration_date == -1: + sandbox_message_ids.append(msg.id) + continue + + # Case 2: Subscription expired beyond grace period + if current_timestamp - expiration_date > graceful_period_seconds: + sandbox_message_ids.append(msg.id) + + return sandbox_message_ids + + +def create_message_clean_policy( + graceful_period_days: int = 21, + current_timestamp: int | None = None, +) -> MessagesCleanPolicy: + """ + Factory function to create the appropriate message clean policy. + + Determines which policy to use based on BILLING_ENABLED configuration: + - If BILLING_ENABLED is True: returns BillingSandboxPolicy + - If BILLING_ENABLED is False: returns BillingDisabledPolicy + + Args: + graceful_period_days: Grace period in days after subscription expiration (default: 21) + current_timestamp: Current Unix timestamp for testing (default: None, uses current time) + """ + if not dify_config.BILLING_ENABLED: + logger.info("create_message_clean_policy: billing disabled, using BillingDisabledPolicy") + return BillingDisabledPolicy() + + # Billing enabled - fetch whitelist from BillingService + tenant_whitelist = BillingService.get_expired_subscription_cleanup_whitelist() + plan_provider = BillingService.get_plan_bulk_with_cache + + logger.info( + "create_message_clean_policy: billing enabled, using BillingSandboxPolicy " + "(graceful_period_days=%s, whitelist=%s)", + graceful_period_days, + tenant_whitelist, + ) + + return BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=graceful_period_days, + tenant_whitelist=tenant_whitelist, + current_timestamp=current_timestamp, + ) diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py new file mode 100644 index 0000000000..3ca5d82860 --- /dev/null +++ b/api/services/retention/conversation/messages_clean_service.py @@ -0,0 +1,334 @@ +import datetime +import logging +import random +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.model import ( + App, + AppAnnotationHitHistory, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.retention.conversation.messages_clean_policy import ( + MessagesCleanPolicy, + SimpleMessage, +) + +logger = logging.getLogger(__name__) + + +class MessagesCleanService: + """ + Service for cleaning expired messages based on retention policies. + + Compatible with non cloud edition (billing disabled): all messages in the time range will be deleted. + If billing is enabled: only sandbox plan tenant messages are deleted (with whitelist and grace period support). + """ + + def __init__( + self, + policy: MessagesCleanPolicy, + end_before: datetime.datetime, + start_from: datetime.datetime | None = None, + batch_size: int = 1000, + dry_run: bool = False, + ) -> None: + """ + Initialize the service with cleanup parameters. + + Args: + policy: The policy that determines which messages to delete + end_before: End time (exclusive) of the range + start_from: Optional start time (inclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + """ + self._policy = policy + self._end_before = end_before + self._start_from = start_from + self._batch_size = batch_size + self._dry_run = dry_run + + @classmethod + def from_time_range( + cls, + policy: MessagesCleanPolicy, + start_from: datetime.datetime, + end_before: datetime.datetime, + batch_size: int = 1000, + dry_run: bool = False, + ) -> "MessagesCleanService": + """ + Create a service instance for cleaning messages within a specific time range. + + Time range is [start_from, end_before). + + Args: + policy: The policy that determines which messages to delete + start_from: Start time (inclusive) of the range + end_before: End time (exclusive) of the range + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + MessagesCleanService instance + + Raises: + ValueError: If start_from >= end_before or invalid parameters + """ + if start_from >= end_before: + raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + logger.info( + "clean_messages: start_from=%s, end_before=%s, batch_size=%s, policy=%s", + start_from, + end_before, + batch_size, + policy.__class__.__name__, + ) + + return cls( + policy=policy, + end_before=end_before, + start_from=start_from, + batch_size=batch_size, + dry_run=dry_run, + ) + + @classmethod + def from_days( + cls, + policy: MessagesCleanPolicy, + days: int = 30, + batch_size: int = 1000, + dry_run: bool = False, + ) -> "MessagesCleanService": + """ + Create a service instance for cleaning messages older than specified days. + + Args: + policy: The policy that determines which messages to delete + days: Number of days to look back from now + batch_size: Number of messages to process per batch + dry_run: Whether to perform a dry run (no actual deletion) + + Returns: + MessagesCleanService instance + + Raises: + ValueError: If invalid parameters + """ + if days < 0: + raise ValueError(f"days ({days}) must be greater than or equal to 0") + + if batch_size <= 0: + raise ValueError(f"batch_size ({batch_size}) must be greater than 0") + + end_before = datetime.datetime.now() - datetime.timedelta(days=days) + + logger.info( + "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s", + days, + end_before, + batch_size, + policy.__class__.__name__, + ) + + return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run) + + def run(self) -> dict[str, int]: + """ + Execute the message cleanup operation. + + Returns: + Dict with statistics: batches, filtered_messages, total_deleted + """ + return self._clean_messages_by_time_range() + + def _clean_messages_by_time_range(self) -> dict[str, int]: + """ + Clean messages within a time range using cursor-based pagination. + + Time range is [start_from, end_before) + + Steps: + 1. Iterate messages using cursor pagination (by created_at, id) + 2. Query app_id -> tenant_id mapping + 3. Delegate to policy to determine which messages to delete + 4. Batch delete messages and their relations + + Returns: + Dict with statistics: batches, filtered_messages, total_deleted + """ + stats = { + "batches": 0, + "total_messages": 0, + "filtered_messages": 0, + "total_deleted": 0, + } + + # Cursor-based pagination using (created_at, id) to avoid infinite loops + # and ensure proper ordering with time-based filtering + _cursor: tuple[datetime.datetime, str] | None = None + + logger.info( + "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s", + self._dry_run, + self._start_from, + self._end_before, + ) + + while True: + stats["batches"] += 1 + + # Step 1: Fetch a batch of messages using cursor + with Session(db.engine, expire_on_commit=False) as session: + msg_stmt = ( + select(Message.id, Message.app_id, Message.created_at) + .where(Message.created_at < self._end_before) + .order_by(Message.created_at, Message.id) + .limit(self._batch_size) + ) + + if self._start_from: + msg_stmt = msg_stmt.where(Message.created_at >= self._start_from) + + # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id) + # This translates to: + # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id) + if _cursor: + # Continuing from previous batch + msg_stmt = msg_stmt.where( + (Message.created_at > _cursor[0]) + | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1])) + ) + + raw_messages = list(session.execute(msg_stmt).all()) + messages = [ + SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at) + for msg_id, app_id, msg_created_at in raw_messages + ] + + # Track total messages fetched across all batches + stats["total_messages"] += len(messages) + + if not messages: + logger.info("clean_messages (batch %s): no more messages to process", stats["batches"]) + break + + # Update cursor to the last message's (created_at, id) + _cursor = (messages[-1].created_at, messages[-1].id) + + # Step 2: Extract app_ids and query tenant_ids + app_ids = list({msg.app_id for msg in messages}) + + if not app_ids: + logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"]) + continue + + app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids)) + apps = list(session.execute(app_stmt).all()) + + if not apps: + logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"]) + continue + + # Build app_id -> tenant_id mapping + app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps} + + # Step 3: Delegate to policy to determine which messages to delete + message_ids_to_delete = self._policy.filter_message_ids(messages, app_to_tenant) + + if not message_ids_to_delete: + logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"]) + continue + + stats["filtered_messages"] += len(message_ids_to_delete) + + # Step 4: Batch delete messages and their relations + if not self._dry_run: + with Session(db.engine, expire_on_commit=False) as session: + # Delete related records first + self._batch_delete_message_relations(session, message_ids_to_delete) + + # Delete messages + delete_stmt = delete(Message).where(Message.id.in_(message_ids_to_delete)) + delete_result = cast(CursorResult, session.execute(delete_stmt)) + messages_deleted = delete_result.rowcount + session.commit() + + stats["total_deleted"] += messages_deleted + + logger.info( + "clean_messages (batch %s): processed %s messages, deleted %s messages", + stats["batches"], + len(messages), + messages_deleted, + ) + else: + # Log random sample of message IDs that would be deleted (up to 10) + sample_size = min(10, len(message_ids_to_delete)) + sampled_ids = random.sample(list(message_ids_to_delete), sample_size) + + logger.info( + "clean_messages (batch %s, dry_run): would delete %s messages, sampling %s ids:", + stats["batches"], + len(message_ids_to_delete), + sample_size, + ) + for msg_id in sampled_ids: + logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id) + + logger.info( + "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s", + stats["batches"], + stats["total_messages"], + stats["filtered_messages"], + stats["total_deleted"], + ) + + return stats + + @staticmethod + def _batch_delete_message_relations(session: Session, message_ids: Sequence[str]) -> None: + """ + Batch delete all related records for given message IDs. + + Args: + session: Database session + message_ids: List of message IDs to delete relations for + """ + if not message_ids: + return + + # Delete all related records in batch + session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids))) + + session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids))) + + session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids))) + + session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids))) + + session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids))) + + session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids))) + + session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids))) + + session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids))) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py new file mode 100644 index 0000000000..5b6db64c09 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -0,0 +1,1071 @@ +import datetime +import json +import uuid +from decimal import Decimal +from unittest.mock import patch + +import pytest +from faker import Faker + +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.billing_service import BillingService +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, + BillingSandboxPolicy, + create_message_clean_policy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +class TestMessagesCleanServiceIntegration: + """Integration tests for MessagesCleanService.run() and _clean_messages_by_time_range().""" + + # Redis cache key prefix from BillingService + PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database before and after each test to ensure isolation.""" + yield + # Clear all test data in correct order (respecting foreign key constraints) + db.session.query(DatasetRetrieverResource).delete() + db.session.query(AppAnnotationHitHistory).delete() + db.session.query(SavedMessage).delete() + db.session.query(MessageFile).delete() + db.session.query(MessageAgentThought).delete() + db.session.query(MessageChain).delete() + db.session.query(MessageAnnotation).delete() + db.session.query(MessageFeedback).delete() + db.session.query(Message).delete() + db.session.query(Conversation).delete() + db.session.query(App).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + @pytest.fixture(autouse=True) + def cleanup_redis(self): + """Clean up Redis cache before each test.""" + # Clear tenant plan cache using BillingService key prefix + try: + keys = redis_client.keys(f"{self.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass # Redis might not be available in some test environments + yield + # Clean up after test + try: + keys = redis_client.keys(f"{self.PLAN_CACHE_KEY_PREFIX}*") + if keys: + redis_client.delete(*keys) + except Exception: + pass + + @pytest.fixture + def mock_whitelist(self): + """Mock whitelist to return empty list by default.""" + with patch( + "services.retention.conversation.messages_clean_policy.BillingService.get_expired_subscription_cleanup_whitelist" + ) as mock: + mock.return_value = [] + yield mock + + @pytest.fixture + def mock_billing_enabled(self): + """Mock BILLING_ENABLED to be True.""" + with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", True): + yield + + @pytest.fixture + def mock_billing_disabled(self): + """Mock BILLING_ENABLED to be False.""" + with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False): + yield + + def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX): + """Helper to create account and tenant.""" + fake = Faker() + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.flush() + + tenant = Tenant( + name=fake.company(), + plan=str(plan), + status="normal", + ) + db.session.add(tenant) + db.session.flush() + + tenant_account_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + ) + db.session.add(tenant_account_join) + db.session.commit() + + return account, tenant + + def _create_app(self, tenant, account): + """Helper to create an app.""" + fake = Faker() + + app = App( + tenant_id=tenant.id, + name=fake.company(), + description="Test app", + mode="chat", + enable_site=True, + enable_api=True, + api_rpm=60, + api_rph=3600, + is_demo=False, + is_public=False, + created_by=account.id, + updated_by=account.id, + ) + db.session.add(app) + db.session.commit() + + return app + + def _create_conversation(self, app): + """Helper to create a conversation.""" + conversation = Conversation( + app_id=app.id, + app_model_config_id=str(uuid.uuid4()), + model_provider="openai", + model_id="gpt-3.5-turbo", + mode="chat", + name="Test conversation", + inputs={}, + status="normal", + from_source="api", + from_end_user_id=str(uuid.uuid4()), + ) + db.session.add(conversation) + db.session.commit() + + return conversation + + def _create_message(self, app, conversation, created_at=None, with_relations=True): + """Helper to create a message with optional related records.""" + if created_at is None: + created_at = datetime.datetime.now() + + message = Message( + app_id=app.id, + conversation_id=conversation.id, + model_provider="openai", + model_id="gpt-3.5-turbo", + inputs={}, + query="Test query", + answer="Test answer", + message=[{"role": "user", "text": "Test message"}], + message_tokens=10, + message_unit_price=Decimal("0.001"), + answer_tokens=20, + answer_unit_price=Decimal("0.002"), + total_price=Decimal("0.003"), + currency="USD", + from_source="api", + from_account_id=conversation.from_end_user_id, + created_at=created_at, + ) + db.session.add(message) + db.session.flush() + + if with_relations: + self._create_message_relations(message) + + db.session.commit() + return message + + def _create_message_relations(self, message): + """Helper to create all message-related records.""" + # MessageFeedback + feedback = MessageFeedback( + app_id=message.app_id, + conversation_id=message.conversation_id, + message_id=message.id, + rating="like", + from_source="api", + from_end_user_id=str(uuid.uuid4()), + ) + db.session.add(feedback) + + # MessageAnnotation + annotation = MessageAnnotation( + app_id=message.app_id, + conversation_id=message.conversation_id, + message_id=message.id, + question="Test question", + content="Test annotation", + account_id=message.from_account_id, + ) + db.session.add(annotation) + + # MessageChain + chain = MessageChain( + message_id=message.id, + type="system", + input=json.dumps({"test": "input"}), + output=json.dumps({"test": "output"}), + ) + db.session.add(chain) + db.session.flush() + + # MessageFile + file = MessageFile( + message_id=message.id, + type="image", + transfer_method="local_file", + url="http://example.com/test.jpg", + belongs_to="user", + created_by_role="end_user", + created_by=str(uuid.uuid4()), + ) + db.session.add(file) + + # SavedMessage + saved = SavedMessage( + app_id=message.app_id, + message_id=message.id, + created_by_role="end_user", + created_by=str(uuid.uuid4()), + ) + db.session.add(saved) + + db.session.flush() + + # AppAnnotationHitHistory + hit = AppAnnotationHitHistory( + app_id=message.app_id, + annotation_id=annotation.id, + message_id=message.id, + source="annotation", + question="Test question", + account_id=message.from_account_id, + score=0.9, + annotation_question="Test annotation question", + annotation_content="Test annotation content", + ) + db.session.add(hit) + + # DatasetRetrieverResource + resource = DatasetRetrieverResource( + message_id=message.id, + position=1, + dataset_id=str(uuid.uuid4()), + dataset_name="Test dataset", + document_id=str(uuid.uuid4()), + document_name="Test document", + data_source_type="upload_file", + segment_id=str(uuid.uuid4()), + score=0.9, + content="Test content", + hit_count=1, + word_count=10, + segment_position=1, + index_node_hash="test_hash", + retriever_from="dataset", + created_by=message.from_account_id, + ) + db.session.add(resource) + + def test_billing_disabled_deletes_all_messages_in_time_range( + self, db_session_with_containers, mock_billing_disabled + ): + """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan.""" + # Arrange - Create tenant with messages (plan doesn't matter for billing disabled) + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create messages: in-range (should be deleted) and out-of-range (should be kept) + in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0) + out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0) + + in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True) + in_range_msg_id = in_range_msg.id + + out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True) + out_of_range_msg_id = out_of_range_msg.id + + # Act - create_message_clean_policy should return BillingDisabledPolicy + policy = create_message_clean_policy() + + assert isinstance(policy, BillingDisabledPolicy) + + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime(2024, 1, 10, 0, 0, 0), + end_before=datetime.datetime(2024, 1, 20, 0, 0, 0), + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 1 # Only in-range message fetched + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # In-range message deleted + assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0 + # Out-of-range message kept + assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 + + # Related records of in-range message deleted + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0 + assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0 + # Related records of out-of-range message kept + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1 + + def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test cleaning when there are no messages to delete (B1).""" + # Arrange + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + start_from = datetime.datetime.now() - datetime.timedelta(days=60) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = {} + + # Act + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - loop runs once to check, finds nothing + assert stats["batches"] == 1 + assert stats["total_messages"] == 0 + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test cleaning with mixed sandbox and paid tenants (B2).""" + # Arrange - Create sandbox tenants with expired messages + sandbox_tenants = [] + sandbox_message_ids = [] + for i in range(2): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + sandbox_tenants.append(tenant) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 3 expired messages per sandbox tenant + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + for j in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + sandbox_message_ids.append(msg.id) + + # Create paid tenants with expired messages (should NOT be deleted) + paid_tenants = [] + paid_message_ids = [] + for i in range(2): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + paid_tenants.append(tenant) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 2 expired messages per paid tenant + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + for j in range(2): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + paid_message_ids.append(msg.id) + + # Mock billing service - return plan and expiration_date + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + expired_15_days_ago = now_timestamp - (15 * 24 * 60 * 60) # Beyond 7-day grace period + + plan_map = {} + for tenant in sandbox_tenants: + plan_map[tenant.id] = { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_15_days_ago, + } + for tenant in paid_tenants: + plan_map[tenant.id] = { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=7) + + assert isinstance(policy, BillingSandboxPolicy) + + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 10 # 2 sandbox * 3 + 2 paid * 2 + assert stats["filtered_messages"] == 6 # 2 sandbox tenants * 3 messages + assert stats["total_deleted"] == 6 + + # Only sandbox messages should be deleted + assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 + # Paid messages should remain + assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 + + # Related records of sandbox messages should be deleted + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0 + assert ( + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() + == 0 + ) + + def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test cursor pagination works correctly across multiple batches (B3).""" + # Arrange - Create sandbox tenant with messages that will span multiple batches + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create 10 expired messages with different timestamps + base_date = datetime.datetime.now() - datetime.timedelta(days=35) + message_ids = [] + for i in range(10): + msg = self._create_message( + app, + conv, + created_at=base_date + datetime.timedelta(hours=i), + with_relations=False, # Skip relations for speed + ) + message_ids.append(msg.id) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act - Use small batch size to trigger multiple batches + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=3, # Small batch size to test pagination + ) + stats = service.run() + + # 5 batches for 10 messages with batch_size=3, the last batch is empty + assert stats["batches"] == 5 + assert stats["total_messages"] == 10 + assert stats["filtered_messages"] == 10 + assert stats["total_deleted"] == 10 + + # All messages should be deleted + assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 + + def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test dry_run mode does not delete messages (B4).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create expired messages + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + message_ids = [] + for i in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + message_ids.append(msg.id) + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + dry_run=True, # Dry run mode + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 3 + assert stats["filtered_messages"] == 3 # Messages identified + assert stats["total_deleted"] == 0 # But NOT deleted + + # All messages should still exist + assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 + # Related records should also still exist + assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 + + def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test when billing returns partial data, unknown tenants are preserved (B5).""" + # Arrange - Create 3 tenants + tenants_data = [] + for i in range(3): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date) + + tenants_data.append( + { + "tenant": tenant, + "message_id": msg.id, + } + ) + + # Mock billing service to return partial data + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + # Only tenant[0] is confirmed as sandbox, tenant[1] is professional, tenant[2] is missing + partial_plan_map = { + tenants_data[0]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + }, + tenants_data[1]["tenant"].id: { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year + }, + # tenants_data[2] is missing from response + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = partial_plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only tenant[0]'s message should be deleted + assert stats["total_messages"] == 3 # 3 tenants * 1 message + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # Check which messages were deleted + assert ( + db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 + ) # Sandbox tenant's message deleted + + assert ( + db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + ) # Professional tenant's message preserved + + assert ( + db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 + ) # Unknown tenant's message preserved (safe default) + + def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test when billing returns empty data, skip deletion entirely (B6).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date) + msg_id = msg.id + db.session.commit() + + # Mock billing service to return empty data (simulating failure/no data scenario) + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = {} # Empty response, tenant plan unknown + + # Act - Should not raise exception, just skip deletion + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - No messages should be deleted when plan is unknown + assert stats["total_messages"] == 1 + assert stats["filtered_messages"] == 0 # Cannot determine sandbox messages + assert stats["total_deleted"] == 0 + + # Message should still exist (safe default - don't delete if plan is unknown) + assert db.session.query(Message).where(Message.id == msg_id).count() == 1 + + def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test that messages are correctly filtered by [start_from, end_before) time range (B7).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create messages: before range, in range, after range + msg_before = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from + with_relations=False, + ) + msg_before_id = msg_before.id + + msg_at_start = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) + with_relations=False, + ) + msg_at_start_id = msg_at_start.id + + msg_in_range = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range + with_relations=False, + ) + msg_in_range_id = msg_in_range.id + + msg_at_end = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) + with_relations=False, + ) + msg_at_end_id = msg_at_end.id + + msg_after = self._create_message( + app, + conv, + created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before + with_relations=False, + ) + msg_after_id = msg_after.id + + db.session.commit() + + # Mock billing service + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, # No previous subscription + } + } + + # Act - Clean with specific time range [2024-01-10, 2024-01-20) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime(2024, 1, 10, 12, 0, 0), + end_before=datetime.datetime(2024, 1, 20, 12, 0, 0), + batch_size=100, + ) + stats = service.run() + + # Assert - Only messages in [start_from, end_before) should be deleted + assert stats["total_messages"] == 2 # Only in-range messages fetched + assert stats["filtered_messages"] == 2 # msg_at_start and msg_in_range + assert stats["total_deleted"] == 2 + + # Verify specific messages using stored IDs + # Before range, kept + assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 + # At start (inclusive), deleted + assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 + # In range, deleted + assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 + # At end (exclusive), kept + assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 + # After range, kept + assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 + + def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test cleaning with different graceful period scenarios (B8).""" + # Arrange - Create 5 different tenants with different plan and expiration scenarios + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + graceful_period = 8 # Use 8 days for this test + + # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) + # Should NOT be deleted + account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app1 = self._create_app(tenant1, account1) + conv1 = self._create_conversation(app1) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1_id = msg1.id + expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period + + # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) + # Should be deleted + account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app2 = self._create_app(tenant2, account2) + conv2 = self._create_conversation(app2) + msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + msg2_id = msg2.id + expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period + + # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) + # Should be deleted + account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app3 = self._create_app(tenant3, account3) + conv3 = self._create_conversation(app3) + msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False) + msg3_id = msg3.id + + # Scenario 4: Non-sandbox plan (professional) with no expiration (future date) + # Should NOT be deleted + account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + app4 = self._create_app(tenant4, account4) + conv4 = self._create_conversation(app4) + msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) + msg4_id = msg4.id + future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year + + # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) + # Should NOT be deleted (boundary is exclusive: > graceful_period) + account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app5 = self._create_app(tenant5, account5) + conv5 = self._create_conversation(app5) + msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) + msg5_id = msg5.id + expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary + + db.session.commit() + + # Mock billing service with all scenarios + plan_map = { + tenant1.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_5_days_ago, + }, + tenant2.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_10_days_ago, + }, + tenant3.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenant4.id: { + "plan": CloudPlan.PROFESSIONAL, + "expiration_date": future_expiration, + }, + tenant5.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_exactly_8_days_ago, + }, + } + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy( + graceful_period_days=graceful_period, + current_timestamp=now_timestamp, # Use fixed timestamp for deterministic behavior + ) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only messages from scenario 2 and 3 should be deleted + assert stats["total_messages"] == 5 # 5 tenants * 1 message + assert stats["filtered_messages"] == 2 + assert stats["total_deleted"] == 2 + + # Verify each scenario using saved IDs + assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept + assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted + assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted + assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept + assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept + + def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test that whitelisted tenants' messages are not deleted (B9).""" + # Arrange - Create 3 sandbox tenants with expired messages + tenants_data = [] + for i in range(3): + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) + + tenants_data.append( + { + "tenant": tenant, + "message_id": msg.id, + } + ) + + # Mock billing service - all tenants are sandbox with no subscription + plan_map = { + tenants_data[0]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenants_data[1]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + tenants_data[2]["tenant"].id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + }, + } + + # Setup whitelist - tenant0 and tenant1 are whitelisted, tenant2 is not + whitelist = [tenants_data[0]["tenant"].id, tenants_data[1]["tenant"].id] + mock_whitelist.return_value = whitelist + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - Only tenant2's message should be deleted (not whitelisted) + assert stats["total_messages"] == 3 # 3 tenants * 1 message + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 1 + + # Verify tenant0's message still exists (whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 + + # Verify tenant1's message still exists (whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + + # Verify tenant2's message was deleted (not whitelisted) + assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 + + def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + """Test from_days correctly cleans messages older than N days (B11).""" + # Arrange + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + # Create old messages (should be deleted - older than 30 days) + old_date = datetime.datetime.now() - datetime.timedelta(days=45) + old_msg_ids = [] + for i in range(3): + msg = self._create_message( + app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False + ) + old_msg_ids.append(msg.id) + + # Create recent messages (should be kept - newer than 30 days) + recent_date = datetime.datetime.now() - datetime.timedelta(days=15) + recent_msg_ids = [] + for i in range(2): + msg = self._create_message( + app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False + ) + recent_msg_ids.append(msg.id) + + db.session.commit() + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + } + } + + # Act - Use from_days to clean messages older than 30 days + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_days( + policy=policy, + days=30, + batch_size=100, + ) + stats = service.run() + + # Assert + assert stats["total_messages"] == 3 # Only old messages in range + assert stats["filtered_messages"] == 3 # Only old messages + assert stats["total_deleted"] == 3 + + # Old messages deleted + assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 + # Recent messages kept + assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 + + def test_whitelist_precedence_over_grace_period( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test that whitelist takes precedence over grace period logic.""" + # Arrange - Create 2 sandbox tenants + now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) + + # Tenant1: whitelisted, expired beyond grace period + account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app1 = self._create_app(tenant1, account1) + conv1 = self._create_conversation(app1) + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace + + # Tenant2: not whitelisted, within grace period + account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app2 = self._create_app(tenant2, account2) + conv2 = self._create_conversation(app2) + msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace + + # Mock billing service + plan_map = { + tenant1.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_30_days_ago, # Beyond grace period + }, + tenant2.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": expired_10_days_ago, # Within grace period + }, + } + + # Setup whitelist - only tenant1 is whitelisted + whitelist = [tenant1.id] + mock_whitelist.return_value = whitelist + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - No messages should be deleted + # tenant1: whitelisted (protected even though beyond grace period) + # tenant2: within grace period (not eligible for deletion) + assert stats["total_messages"] == 2 # 2 tenants * 1 message + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + # Verify both messages still exist + assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted + assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period + + def test_empty_whitelist_deletes_eligible_messages( + self, db_session_with_containers, mock_billing_enabled, mock_whitelist + ): + """Test that empty whitelist behaves as no whitelist (all eligible messages deleted).""" + # Arrange - Create sandbox tenant with expired messages + account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + app = self._create_app(tenant, account) + conv = self._create_conversation(app) + + expired_date = datetime.datetime.now() - datetime.timedelta(days=35) + msg_ids = [] + for i in range(3): + msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg_ids.append(msg.id) + + # Mock billing service + plan_map = { + tenant.id: { + "plan": CloudPlan.SANDBOX, + "expiration_date": -1, + } + } + + # Setup empty whitelist (default behavior from fixture) + mock_whitelist.return_value = [] + + with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: + mock_billing.return_value = plan_map + + # Act + end_before = datetime.datetime.now() - datetime.timedelta(days=30) + policy = create_message_clean_policy(graceful_period_days=21) + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=datetime.datetime.now() - datetime.timedelta(days=60), + end_before=end_before, + batch_size=100, + ) + stats = service.run() + + # Assert - All messages should be deleted (no whitelist protection) + assert stats["total_messages"] == 3 + assert stats["filtered_messages"] == 3 + assert stats["total_deleted"] == 3 + + # Verify all messages were deleted + assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py new file mode 100644 index 0000000000..3b619195c7 --- /dev/null +++ b/api/tests/unit_tests/services/test_messages_clean_service.py @@ -0,0 +1,627 @@ +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from enums.cloud_plan import CloudPlan +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, + BillingSandboxPolicy, + SimpleMessage, + create_message_clean_policy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +def make_simple_message(msg_id: str, app_id: str) -> SimpleMessage: + """Helper to create a SimpleMessage with a fixed created_at timestamp.""" + return SimpleMessage(id=msg_id, app_id=app_id, created_at=datetime.datetime(2024, 1, 1)) + + +def make_plan_provider(tenant_plans: dict) -> MagicMock: + """Helper to create a mock plan_provider that returns the given tenant_plans.""" + provider = MagicMock() + provider.return_value = tenant_plans + return provider + + +class TestBillingSandboxPolicyFilterMessageIds: + """Unit tests for BillingSandboxPolicy.filter_message_ids method.""" + + # Fixed timestamp for deterministic tests + CURRENT_TIMESTAMP = 1000000 + GRACEFUL_PERIOD_DAYS = 8 + GRACEFUL_PERIOD_SECONDS = GRACEFUL_PERIOD_DAYS * 24 * 60 * 60 + + def test_missing_tenant_mapping_excluded(self): + """Test that messages with missing app-to-tenant mapping are excluded.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {} # No mapping + tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}} + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_missing_tenant_plan_excluded(self): + """Test that messages with missing tenant plan are excluded (safe default).""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = {} # No plans + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_non_sandbox_plan_excluded(self): + """Test that messages from non-sandbox plans (PROFESSIONAL/TEAM) are excluded.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.TEAM, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, # Only this one + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg3 (sandbox tenant) should be included + assert set(result) == {"msg3"} + + def test_whitelist_skip(self): + """Test that whitelisted tenants are excluded even if sandbox + expired.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), # Whitelisted - excluded + make_simple_message("msg2", "app2"), # Not whitelisted - included + make_simple_message("msg3", "app3"), # Whitelisted - excluded + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + plan_provider = make_plan_provider(tenant_plans) + tenant_whitelist = ["tenant1", "tenant3"] + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + tenant_whitelist=tenant_whitelist, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg2 should be included + assert set(result) == {"msg2"} + + def test_no_previous_subscription_included(self): + """Test that messages with expiration_date=-1 (no previous subscription) are included.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all messages should be included + assert set(result) == {"msg1", "msg2"} + + def test_within_grace_period_excluded(self): + """Test that messages within grace period are excluded.""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_1_day_ago = now - (1 * 24 * 60 * 60) + expired_5_days_ago = now - (5 * 24 * 60 * 60) + expired_7_days_ago = now - (7 * 24 * 60 * 60) + + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant3"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, # 8 days + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all within 8-day grace period, none should be included + assert list(result) == [] + + def test_exactly_at_boundary_excluded(self): + """Test that messages exactly at grace period boundary are excluded (code uses >).""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_exactly_8_days_ago = now - self.GRACEFUL_PERIOD_SECONDS # Exactly at boundary + + messages = [make_simple_message("msg1", "app1")] + app_to_tenant = {"app1": "tenant1"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - exactly at boundary (==) should be excluded (code uses >) + assert list(result) == [] + + def test_beyond_grace_period_included(self): + """Test that messages beyond grace period are included.""" + # Arrange + now = self.CURRENT_TIMESTAMP + expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond 8-day grace + expired_30_days_ago = now - (30 * 24 * 60 * 60) # Well beyond + + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago}, + } + plan_provider = make_plan_provider(tenant_plans) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - both beyond grace period, should be included + assert set(result) == {"msg1", "msg2"} + + def test_empty_messages_returns_empty(self): + """Test that empty messages returns empty list.""" + # Arrange + messages: list[SimpleMessage] = [] + app_to_tenant = {"app1": "tenant1"} + plan_provider = make_plan_provider({"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + def test_plan_provider_called_with_correct_tenant_ids(self): + """Test that plan_provider is called with correct tenant_ids.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2", "app3": "tenant1"} # tenant1 appears twice + plan_provider = make_plan_provider({}) + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + current_timestamp=self.CURRENT_TIMESTAMP, + ) + + # Act + policy.filter_message_ids(messages, app_to_tenant) + + # Assert - plan_provider should be called once with unique tenant_ids + plan_provider.assert_called_once() + called_tenant_ids = set(plan_provider.call_args[0][0]) + assert called_tenant_ids == {"tenant1", "tenant2"} + + def test_complex_mixed_scenario(self): + """Test complex scenario with mixed plans, expirations, whitelist, and missing mappings.""" + # Arrange + now = self.CURRENT_TIMESTAMP + sandbox_expired_old = now - (15 * 24 * 60 * 60) # Beyond grace + sandbox_expired_recent = now - (3 * 24 * 60 * 60) # Within grace + future_expiration = now + (30 * 24 * 60 * 60) + + messages = [ + make_simple_message("msg1", "app1"), # Sandbox, no subscription - included + make_simple_message("msg2", "app2"), # Sandbox, expired old - included + make_simple_message("msg3", "app3"), # Sandbox, within grace - excluded + make_simple_message("msg4", "app4"), # Team plan, active - excluded + make_simple_message("msg5", "app5"), # No tenant mapping - excluded + make_simple_message("msg6", "app6"), # No plan info - excluded + make_simple_message("msg7", "app7"), # Sandbox, expired old, whitelisted - excluded + ] + app_to_tenant = { + "app1": "tenant1", + "app2": "tenant2", + "app3": "tenant3", + "app4": "tenant4", + "app6": "tenant6", # Has mapping but no plan + "app7": "tenant7", + # app5 has no mapping + } + tenant_plans = { + "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent}, + "tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration}, + "tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old}, + # tenant6 has no plan + } + plan_provider = make_plan_provider(tenant_plans) + tenant_whitelist = ["tenant7"] + + policy = BillingSandboxPolicy( + plan_provider=plan_provider, + graceful_period_days=self.GRACEFUL_PERIOD_DAYS, + tenant_whitelist=tenant_whitelist, + current_timestamp=now, + ) + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - only msg1 and msg2 should be included + assert set(result) == {"msg1", "msg2"} + + +class TestBillingDisabledPolicyFilterMessageIds: + """Unit tests for BillingDisabledPolicy.filter_message_ids method.""" + + def test_returns_all_message_ids(self): + """Test that all message IDs are returned (order-preserving).""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + make_simple_message("msg3", "app3"), + ] + app_to_tenant = {"app1": "tenant1", "app2": "tenant2"} + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all message IDs returned in order + assert list(result) == ["msg1", "msg2", "msg3"] + + def test_ignores_app_to_tenant(self): + """Test that app_to_tenant mapping is ignored.""" + # Arrange + messages = [ + make_simple_message("msg1", "app1"), + make_simple_message("msg2", "app2"), + ] + app_to_tenant: dict[str, str] = {} # Empty - should be ignored + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert - all message IDs still returned + assert list(result) == ["msg1", "msg2"] + + def test_empty_messages_returns_empty(self): + """Test that empty messages returns empty list.""" + # Arrange + messages: list[SimpleMessage] = [] + app_to_tenant = {"app1": "tenant1"} + + policy = BillingDisabledPolicy() + + # Act + result = policy.filter_message_ids(messages, app_to_tenant) + + # Assert + assert list(result) == [] + + +class TestCreateMessageCleanPolicy: + """Unit tests for create_message_clean_policy factory function.""" + + @patch("services.retention.conversation.messages_clean_policy.dify_config") + def test_billing_disabled_returns_billing_disabled_policy(self, mock_config): + """Test that BILLING_ENABLED=False returns BillingDisabledPolicy.""" + # Arrange + mock_config.BILLING_ENABLED = False + + # Act + policy = create_message_clean_policy(graceful_period_days=21) + + # Assert + assert isinstance(policy, BillingDisabledPolicy) + + @patch("services.retention.conversation.messages_clean_policy.BillingService") + @patch("services.retention.conversation.messages_clean_policy.dify_config") + def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service): + """Test that BillingSandboxPolicy is created with correct internal values.""" + # Arrange + mock_config.BILLING_ENABLED = True + whitelist = ["tenant1", "tenant2"] + mock_billing_service.get_expired_subscription_cleanup_whitelist.return_value = whitelist + mock_plan_provider = MagicMock() + mock_billing_service.get_plan_bulk_with_cache = mock_plan_provider + + # Act + policy = create_message_clean_policy(graceful_period_days=14, current_timestamp=1234567) + + # Assert + mock_billing_service.get_expired_subscription_cleanup_whitelist.assert_called_once() + assert isinstance(policy, BillingSandboxPolicy) + assert policy._graceful_period_days == 14 + assert list(policy._tenant_whitelist) == whitelist + assert policy._plan_provider == mock_plan_provider + assert policy._current_timestamp == 1234567 + + +class TestMessagesCleanServiceFromTimeRange: + """Unit tests for MessagesCleanService.from_time_range factory method.""" + + def test_start_from_end_before_raises_value_error(self): + """Test that start_from == end_before raises ValueError.""" + policy = BillingDisabledPolicy() + + # Arrange + same_time = datetime.datetime(2024, 1, 1, 12, 0, 0) + + # Act & Assert + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=same_time, + end_before=same_time, + ) + + # Arrange + start_from = datetime.datetime(2024, 12, 31) + end_before = datetime.datetime(2024, 1, 1) + + # Act & Assert + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + ) + + def test_batch_size_raises_value_error(self): + """Test that batch_size=0 raises ValueError.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=0, + ) + + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=-100, + ) + + def test_valid_params_creates_instance(self): + """Test that valid parameters create a correctly configured instance.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 12, 31, 23, 59, 59) + policy = BillingDisabledPolicy() + batch_size = 500 + dry_run = True + + # Act + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + ) + + # Assert + assert isinstance(service, MessagesCleanService) + assert service._policy is policy + assert service._start_from == start_from + assert service._end_before == end_before + assert service._batch_size == batch_size + assert service._dry_run == dry_run + + def test_default_params(self): + """Test that default parameters are applied correctly.""" + # Arrange + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 2, 1) + policy = BillingDisabledPolicy() + + # Act + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + ) + + # Assert + assert service._batch_size == 1000 # default + assert service._dry_run is False # default + + +class TestMessagesCleanServiceFromDays: + """Unit tests for MessagesCleanService.from_days factory method.""" + + def test_days_raises_value_error(self): + """Test that days < 0 raises ValueError.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + MessagesCleanService.from_days(policy=policy, days=-1) + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days(policy=policy, days=0) + + # Assert + assert service._end_before == fixed_now + + def test_batch_size_raises_value_error(self): + """Test that batch_size=0 raises ValueError.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy=policy, days=30, batch_size=0) + + # Act & Assert + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy=policy, days=30, batch_size=-500) + + def test_valid_params_creates_instance(self): + """Test that valid parameters create a correctly configured instance.""" + # Arrange + policy = BillingDisabledPolicy() + days = 90 + batch_size = 500 + dry_run = True + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days( + policy=policy, + days=days, + batch_size=batch_size, + dry_run=dry_run, + ) + + # Assert + expected_end_before = fixed_now - datetime.timedelta(days=days) + assert isinstance(service, MessagesCleanService) + assert service._policy is policy + assert service._start_from is None + assert service._end_before == expected_end_before + assert service._batch_size == batch_size + assert service._dry_run == dry_run + + def test_default_params(self): + """Test that default parameters are applied correctly.""" + # Arrange + policy = BillingDisabledPolicy() + + # Act + with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime: + fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) + mock_datetime.datetime.now.return_value = fixed_now + mock_datetime.timedelta = datetime.timedelta + + service = MessagesCleanService.from_days(policy=policy) + + # Assert + expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30 + assert service._end_before == expected_end_before + assert service._batch_size == 1000 # default + assert service._dry_run is False # default diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 5ffa87375c..7ea784baa3 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -21,7 +21,6 @@ import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/ import FileUploadSetting from '@/app/components/workflow/nodes/_base/components/file-upload-setting' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import { ChangeType, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types' -import { DEFAULT_VALUE_MAX_LEN } from '@/config' import ConfigContext from '@/context/debug-configuration' import { AppModeEnum, TransferMethod } from '@/types/app' import { checkKeys, getNewVarInWorkflow, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var' @@ -198,8 +197,6 @@ const ConfigModal: FC = ({ if (type === InputVarType.multiFiles) draft.max_length = DEFAULT_FILE_UPLOAD_SETTING.max_length } - if (type === InputVarType.paragraph) - draft.max_length = DEFAULT_VALUE_MAX_LEN }) setTempPayload(newPayload) }, [tempPayload]) diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 4a38fc92a6..1a8810f7cd 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -15,7 +15,6 @@ import Confirm from '@/app/components/base/confirm' import Toast from '@/app/components/base/toast' import Tooltip from '@/app/components/base/tooltip' import { InputVarType } from '@/app/components/workflow/types' -import { DEFAULT_VALUE_MAX_LEN } from '@/config' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' import { useModalContext } from '@/context/modal-context' @@ -58,8 +57,6 @@ const buildPromptVariableFromInput = (payload: InputVar): PromptVariable => { key: variable, name: label as string, } - if (payload.type === InputVarType.textInput) - nextItem.max_length = nextItem.max_length || DEFAULT_VALUE_MAX_LEN if (payload.type !== InputVarType.select) delete nextItem.options diff --git a/web/app/components/app/configuration/debug/chat-user-input.tsx b/web/app/components/app/configuration/debug/chat-user-input.tsx index 11189751e0..3f9fdc32be 100644 --- a/web/app/components/app/configuration/debug/chat-user-input.tsx +++ b/web/app/components/app/configuration/debug/chat-user-input.tsx @@ -7,7 +7,6 @@ import Input from '@/app/components/base/input' import Select from '@/app/components/base/select' import Textarea from '@/app/components/base/textarea' import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' -import { DEFAULT_VALUE_MAX_LEN } from '@/config' import ConfigContext from '@/context/debug-configuration' import { cn } from '@/utils/classnames' @@ -88,7 +87,7 @@ const ChatUserInput = ({ onChange={(e) => { handleInputValueChange(key, e.target.value) }} placeholder={name} autoFocus={index === 0} - maxLength={max_length || DEFAULT_VALUE_MAX_LEN} + maxLength={max_length} /> )} {type === 'paragraph' && ( @@ -115,7 +114,7 @@ const ChatUserInput = ({ onChange={(e) => { handleInputValueChange(key, e.target.value) }} placeholder={name} autoFocus={index === 0} - maxLength={max_length || DEFAULT_VALUE_MAX_LEN} + maxLength={max_length} /> )} {type === 'checkbox' && ( diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx index 9b61b3c7aa..613efb8710 100644 --- a/web/app/components/app/configuration/prompt-value-panel/index.tsx +++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx @@ -20,7 +20,6 @@ import Select from '@/app/components/base/select' import Textarea from '@/app/components/base/textarea' import Tooltip from '@/app/components/base/tooltip' import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' -import { DEFAULT_VALUE_MAX_LEN } from '@/config' import ConfigContext from '@/context/debug-configuration' import { AppModeEnum, ModelModeType } from '@/types/app' import { cn } from '@/utils/classnames' @@ -142,7 +141,7 @@ const PromptValuePanel: FC = ({ onChange={(e) => { handleInputValueChange(key, e.target.value) }} placeholder={name} autoFocus={index === 0} - maxLength={max_length || DEFAULT_VALUE_MAX_LEN} + maxLength={max_length} /> )} {type === 'paragraph' && ( @@ -170,7 +169,7 @@ const PromptValuePanel: FC = ({ onChange={(e) => { handleInputValueChange(key, e.target.value) }} placeholder={name} autoFocus={index === 0} - maxLength={max_length || DEFAULT_VALUE_MAX_LEN} + maxLength={max_length} /> )} {type === 'checkbox' && ( diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 095ed3f696..84150ad480 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -12,7 +12,6 @@ import { useDebounceFn } from 'ahooks' import dynamic from 'next/dynamic' import { useRouter, - useSearchParams, } from 'next/navigation' import { parseAsString, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' @@ -29,7 +28,6 @@ import { CheckModal } from '@/hooks/use-pay' import { useInfiniteAppList } from '@/service/use-apps' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' -import { isServer } from '@/utils/client' import AppCard from './app-card' import { AppCardSkeleton } from './app-card-skeleton' import Empty from './empty' @@ -59,7 +57,6 @@ const List = () => { const { t } = useTranslation() const { systemFeatures } = useGlobalPublicStore() const router = useRouter() - const searchParams = useSearchParams() const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext() const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [activeTab, setActiveTab] = useQueryState( @@ -67,33 +64,6 @@ const List = () => { parseAsString.withDefault('all').withOptions({ history: 'push' }), ) - // valid tabs for apps list; anything else should fallback to 'all' - - // 1) Normalize legacy/incorrect query params like ?mode=discover -> ?category=all - useEffect(() => { - // avoid running on server - if (isServer) - return - const mode = searchParams.get('mode') - if (!mode) - return - const url = new URL(window.location.href) - url.searchParams.delete('mode') - if (validTabs.has(mode)) { - // migrate to category key - url.searchParams.set('category', mode) - } - else { - url.searchParams.set('category', 'all') - } - router.replace(url.pathname + url.search) - }, [router, searchParams]) - - // 2) If category has an invalid value (e.g., 'discover'), reset to 'all' - useEffect(() => { - if (!validTabs.has(activeTab)) - setActiveTab('all') - }, [activeTab, setActiveTab]) const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState() const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe) const [tagFilterValue, setTagFilterValue] = useState(tagIDs) diff --git a/web/app/components/header/nav/index.tsx b/web/app/components/header/nav/index.tsx index 83e75b8513..2edc64486e 100644 --- a/web/app/components/header/nav/index.tsx +++ b/web/app/components/header/nav/index.tsx @@ -2,9 +2,9 @@ import type { INavSelectorProps } from './nav-selector' import Link from 'next/link' -import { usePathname, useSearchParams, useSelectedLayoutSegment } from 'next/navigation' +import { useSelectedLayoutSegment } from 'next/navigation' import * as React from 'react' -import { useEffect, useState } from 'react' +import { useState } from 'react' import { useStore as useAppStore } from '@/app/components/app/store' import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' import { cn } from '@/utils/classnames' @@ -36,14 +36,6 @@ const Nav = ({ const [hovered, setHovered] = useState(false) const segment = useSelectedLayoutSegment() const isActivated = Array.isArray(activeSegment) ? activeSegment.includes(segment!) : segment === activeSegment - const pathname = usePathname() - const searchParams = useSearchParams() - const [linkLastSearchParams, setLinkLastSearchParams] = useState('') - - useEffect(() => { - if (pathname === link) - setLinkLastSearchParams(searchParams.toString()) - }, [pathname, searchParams]) return (
- +
{ // Don't clear state if opening in new tab/window diff --git a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/hooks.ts b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/hooks.ts index 3820d5f1b8..80aa879b8f 100644 --- a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/hooks.ts +++ b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/hooks.ts @@ -6,7 +6,6 @@ import { useTranslation } from 'react-i18next' import { useFileSizeLimit } from '@/app/components/base/file-uploader/hooks' import { InputFieldType } from '@/app/components/base/form/form-scenarios/input-field/types' import { DEFAULT_FILE_UPLOAD_SETTING } from '@/app/components/workflow/constants' -import { DEFAULT_VALUE_MAX_LEN } from '@/config' import { PipelineInputVarType } from '@/models/pipeline' import { useFileUploadConfig } from '@/service/use-common' import { formatFileSize } from '@/utils/format' @@ -87,8 +86,6 @@ export const useConfigurations = (props: { if (type === PipelineInputVarType.multiFiles) setFieldValue('maxLength', DEFAULT_FILE_UPLOAD_SETTING.max_length) } - if (type === PipelineInputVarType.paragraph) - setFieldValue('maxLength', DEFAULT_VALUE_MAX_LEN) }, [setFieldValue]) const handleVariableNameBlur = useCallback((value: string) => { diff --git a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/index.spec.tsx b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/index.spec.tsx index 0470bd4c68..48df13acb2 100644 --- a/web/app/components/rag-pipeline/components/panel/input-field/editor/form/index.spec.tsx +++ b/web/app/components/rag-pipeline/components/panel/input-field/editor/form/index.spec.tsx @@ -779,27 +779,6 @@ describe('useConfigurations', () => { expect(mockSetFieldValue).toHaveBeenCalledWith('maxLength', expect.any(Number)) }) - it('should call setFieldValue when type changes to paragraph', () => { - // Arrange - const mockGetFieldValue = vi.fn() - const mockSetFieldValue = vi.fn() - - const { result } = renderHookWithProviders(() => - useConfigurations({ - getFieldValue: mockGetFieldValue, - setFieldValue: mockSetFieldValue, - supportFile: false, - }), - ) - - // Act - const typeConfig = result.current.find(config => config.variable === 'type') - typeConfig?.listeners?.onChange?.(createMockEvent(PipelineInputVarType.paragraph)) - - // Assert - expect(mockSetFieldValue).toHaveBeenCalledWith('maxLength', 48) // DEFAULT_VALUE_MAX_LEN - }) - it('should set label from variable name on blur when label is empty', () => { // Arrange const mockGetFieldValue = vi.fn().mockReturnValue('') diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index b793a03ce7..509687e245 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -26,7 +26,7 @@ import DifyLogo from '@/app/components/base/logo/dify-logo' import Toast from '@/app/components/base/toast' import Res from '@/app/components/share/text-generation/result' import RunOnce from '@/app/components/share/text-generation/run-once' -import { appDefaultIconBackground, BATCH_CONCURRENCY, DEFAULT_VALUE_MAX_LEN } from '@/config' +import { appDefaultIconBackground, BATCH_CONCURRENCY } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' import { useWebAppStore } from '@/context/web-app-context' import { useAppFavicon } from '@/hooks/use-app-favicon' @@ -256,11 +256,10 @@ const TextGeneration: FC = ({ promptConfig?.prompt_variables.forEach((varItem, varIndex) => { if (errorRowIndex !== 0) return - if (varItem.type === 'string') { - const maxLen = varItem.max_length || DEFAULT_VALUE_MAX_LEN - if (item[varIndex].length > maxLen) { + if (varItem.type === 'string' && varItem.max_length) { + if (item[varIndex].length > varItem.max_length) { moreThanMaxLengthVarName = varItem.name - maxLength = maxLen + maxLength = varItem.max_length errorRowIndex = index + 1 return } diff --git a/web/app/components/share/text-generation/run-once/index.spec.tsx b/web/app/components/share/text-generation/run-once/index.spec.tsx index 8882253d0e..ea5ce3c902 100644 --- a/web/app/components/share/text-generation/run-once/index.spec.tsx +++ b/web/app/components/share/text-generation/run-once/index.spec.tsx @@ -236,4 +236,46 @@ describe('RunOnce', () => { const stopButton = screen.getByTestId('stop-button') expect(stopButton).toBeDisabled() }) + + describe('maxLength behavior', () => { + it('should not have maxLength attribute when max_length is not set', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'textInput', + name: 'Text Input', + type: 'string', + // max_length is not set + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + const input = screen.getByPlaceholderText('Text Input') + expect(input).not.toHaveAttribute('maxLength') + }) + + it('should have maxLength attribute when max_length is set', async () => { + const promptConfig: PromptConfig = { + prompt_template: 'template', + prompt_variables: [ + createPromptVariable({ + key: 'textInput', + name: 'Text Input', + type: 'string', + max_length: 100, + }), + ], + } + const { onInputsChange } = setup({ promptConfig, visionConfig: { ...baseVisionConfig, enabled: false } }) + await waitFor(() => { + expect(onInputsChange).toHaveBeenCalled() + }) + const input = screen.getByPlaceholderText('Text Input') + expect(input).toHaveAttribute('maxLength', '100') + }) + }) }) diff --git a/web/app/components/share/text-generation/run-once/index.tsx b/web/app/components/share/text-generation/run-once/index.tsx index b8193fd944..ca29ce1a98 100644 --- a/web/app/components/share/text-generation/run-once/index.tsx +++ b/web/app/components/share/text-generation/run-once/index.tsx @@ -19,7 +19,6 @@ import Textarea from '@/app/components/base/textarea' import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' -import { DEFAULT_VALUE_MAX_LEN } from '@/config' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { cn } from '@/utils/classnames' @@ -140,7 +139,7 @@ const RunOnce: FC = ({ placeholder={item.name} value={inputs[item.key]} onChange={(e: ChangeEvent) => { handleInputsChange({ ...inputsRef.current, [item.key]: e.target.value }) }} - maxLength={item.max_length || DEFAULT_VALUE_MAX_LEN} + maxLength={item.max_length} /> )} {item.type === 'paragraph' && ( diff --git a/web/config/index.ts b/web/config/index.ts index b804629048..08ce14b264 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -208,7 +208,6 @@ export const VAR_ITEM_TEMPLATE = { key: '', name: '', type: 'string', - max_length: DEFAULT_VALUE_MAX_LEN, required: true, } @@ -216,7 +215,6 @@ export const VAR_ITEM_TEMPLATE_IN_WORKFLOW = { variable: '', label: '', type: InputVarType.textInput, - max_length: DEFAULT_VALUE_MAX_LEN, required: true, options: [], } @@ -225,7 +223,6 @@ export const VAR_ITEM_TEMPLATE_IN_PIPELINE = { variable: '', label: '', type: PipelineInputVarType.textInput, - max_length: DEFAULT_VALUE_MAX_LEN, required: true, options: [], } diff --git a/web/package.json b/web/package.json index 000862204b..5ca90c75ea 100644 --- a/web/package.json +++ b/web/package.json @@ -28,9 +28,10 @@ "build:docker": "next build && node scripts/optimize-standalone.js", "start": "node ./scripts/copy-and-start.mjs", "lint": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache", - "lint:fix": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --fix", - "lint:quiet": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --quiet", - "lint:complexity": "eslint --cache --cache-location node_modules/.cache/eslint/.eslint-cache --rule 'complexity: [error, {max: 15}]' --quiet", + "lint:fix": "pnpm lint --fix", + "lint:quiet": "pnpm lint --quiet", + "lint:complexity": "pnpm lint --rule 'complexity: [error, {max: 15}]' --quiet", + "lint:report": "pnpm lint --output-file eslint_report.json --format json", "type-check": "tsc --noEmit", "type-check:tsgo": "tsgo --noEmit", "prepare": "cd ../ && node -e \"if (process.env.NODE_ENV !== 'production'){process.exit(1)} \" || husky ./web/.husky", diff --git a/web/utils/var.ts b/web/utils/var.ts index 4f572d7768..1851084b2e 100644 --- a/web/utils/var.ts +++ b/web/utils/var.ts @@ -30,7 +30,7 @@ export const getNewVar = (key: string, type: string) => { } export const getNewVarInWorkflow = (key: string, type = InputVarType.textInput): InputVar => { - const { max_length: _maxLength, ...rest } = VAR_ITEM_TEMPLATE_IN_WORKFLOW + const { ...rest } = VAR_ITEM_TEMPLATE_IN_WORKFLOW if (type !== InputVarType.textInput) { return { ...rest,