diff --git a/api/.env.example b/api/.env.example index 44d770ed70..8099c4a42a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -589,6 +589,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false ENABLE_CREATE_TIDB_SERVERLESS_TASK=false ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false ENABLE_CLEAN_MESSAGES=false +ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true diff --git a/api/commands.py b/api/commands.py index 7ebf5b4874..e24b1826ee 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,4 +1,5 @@ import base64 +import datetime import json import logging import secrets @@ -45,6 +46,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration from services.plugin.plugin_service import PluginService +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup from tasks.remove_app_and_related_data_task import delete_draft_variables_batch logger = logging.getLogger(__name__) @@ -852,6 +854,61 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[ click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) +@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.") +@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.") +@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option( + "--dry-run", + is_flag=True, + help="Preview cleanup results without deleting any workflow run data.", +) +def clean_workflow_runs( + days: int, + batch_size: int, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + dry_run: bool, +): + """ + Clean workflow runs and related workflow data for free tenants. + """ + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + + start_time = datetime.datetime.now(datetime.UTC) + click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white")) + + WorkflowRunCleanup( + days=days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + dry_run=dry_run, + ).run() + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Workflow run cleanup completed. 
start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + @click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") @click.command("clear-orphaned-file-records", help="Clear orphaned file records.") def clear_orphaned_file_records(force: bool): diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 6a04171d2d..cf855b1cc0 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1101,6 +1101,10 @@ class CeleryScheduleTasksConfig(BaseSettings): description="Enable clean messages task", default=False, ) + ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field( + description="Enable scheduled workflow run cleanup task", + default=False, + ) ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field( description="Enable mail clean document notify task", default=False, diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index c08b62a253..bb3b13e8c6 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum): def is_ended(self) -> bool: return self in _END_STATE + @classmethod + def ended_values(cls) -> list[str]: + return [status.value for status in _END_STATE] + _END_STATE = frozenset( [ diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 2fbab001d0..08cf96c1c1 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -163,6 +163,13 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise", "schedule": crontab(minute="0", hour="2"), } + if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK: + # for saas only + imports.append("schedule.clean_workflow_runs_task") + beat_schedule["clean_workflow_runs_task"] = { + "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task", + "schedule": crontab(minute="0", hour="0"), + } if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: imports.append("schedule.workflow_schedule_task") beat_schedule["workflow_schedule_task"] = { diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index daa3756dba..c32130d377 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + clean_workflow_runs, cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, @@ -56,6 +57,7 @@ def init_app(app: DifyApp): setup_datasource_oauth_client, transform_datasource_credentials, install_rag_pipeline_plugins, + clean_workflow_runs, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py new file mode 100644 index 0000000000..7e0cc8ec9d --- /dev/null +++ b/api/migrations/versions/2026_01_09_1630-905527cc8fd3_.py @@ -0,0 +1,30 @@ +"""add workflow_run_created_at_id_idx + +Revision ID: 905527cc8fd3 +Revises: 7df29de0f6be +Create Date: 2025-01-09 16:30:02.462084 + +""" +from alembic import op +import models as models + +# revision identifiers, used by Alembic. +revision = '905527cc8fd3' +down_revision = '7df29de0f6be' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.create_index('workflow_run_created_at_id_idx', ['created_at', 'id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index('workflow_run_created_at_id_idx') + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index a18939523b..072c6100b5 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -597,6 +597,7 @@ class WorkflowRun(Base): __table_args__ = ( sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"), sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + sa.Index("workflow_run_created_at_id_idx", "created_at", "id"), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index fd547c78ba..1a2b84fdf9 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -34,11 +34,14 @@ Example: ``` """ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol +from sqlalchemy.orm import Session + from core.workflow.entities.pause_reason import PauseReason +from core.workflow.enums import WorkflowType from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom @@ -253,6 +256,44 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_runs_batch_by_time_range( + self, + start_from: datetime | None, + end_before: datetime, + last_seen: tuple[datetime, str] | None, + batch_size: int, + run_types: Sequence[WorkflowType] | None = None, + tenant_ids: Sequence[str] | None = None, + ) -> Sequence[WorkflowRun]: + """ + Fetch ended workflow runs in a time window for archival and clean batching. + """ + ... + + def delete_runs_with_related( + self, + runs: Sequence[WorkflowRun], + delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + """ + Delete workflow runs and their related records (node executions, offloads, app logs, + trigger logs, pauses, pause reasons). + """ + ... + + def count_runs_with_related( + self, + runs: Sequence[WorkflowRun], + count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + """ + Count workflow runs and their related records (node executions, offloads, app logs, + trigger logs, pauses, pause reasons) without deleting data. + """ + ... 
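A note on how these three new protocol methods are meant to be used together: the caller owns the batching loop and injects the cross-repository deleters as callables, so this repository never has to import the node-execution or trigger-log repositories. Below is a minimal sketch of that wiring, with placeholder callbacks standing in for the real ones (the cleanup service injects DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs and SQLAlchemyWorkflowTriggerLogRepository.delete_by_run_ids); tenant/plan filtering, which the cleanup service layers on top, is omitted here.

    from collections.abc import Sequence
    from datetime import datetime

    from sqlalchemy.orm import Session

    from models.workflow import WorkflowRun
    from repositories.api_workflow_run_repository import APIWorkflowRunRepository


    def _skip_node_executions(session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
        # Placeholder: return (node_executions_deleted, offloads_deleted).
        return 0, 0


    def _skip_trigger_logs(session: Session, run_ids: Sequence[str]) -> int:
        # Placeholder: return the number of trigger log rows deleted.
        return 0


    def purge_runs_before(repo: APIWorkflowRunRepository, end_before: datetime, batch_size: int = 200) -> int:
        deleted = 0
        last_seen: tuple[datetime, str] | None = None
        while True:
            runs = repo.get_runs_batch_by_time_range(
                start_from=None,
                end_before=end_before,
                last_seen=last_seen,
                batch_size=batch_size,
            )
            if not runs:
                return deleted
            # Keyset cursor over (created_at, id), matching workflow_run_created_at_id_idx.
            last_seen = (runs[-1].created_at, runs[-1].id)
            counts = repo.delete_runs_with_related(
                runs,
                delete_node_executions=_skip_node_executions,
                delete_trigger_logs=_skip_trigger_logs,
            )
            deleted += counts["runs"]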
+ def create_workflow_pause( self, workflow_run_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 7e2173acdd..2de3a15d65 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -7,13 +7,18 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. from collections.abc import Sequence from datetime import datetime -from typing import cast +from typing import TypedDict, cast -from sqlalchemy import asc, delete, desc, select +from sqlalchemy import asc, delete, desc, func, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from models.workflow import WorkflowNodeExecutionModel +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import ( + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowNodeExecutionTriggeredFrom, +) from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -44,6 +49,26 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut """ self._session_maker = session_maker + @staticmethod + def _map_run_triggered_from_to_node_triggered_from(triggered_from: str) -> str: + """ + Map workflow run triggered_from values to workflow node execution triggered_from values. + """ + if triggered_from in { + WorkflowRunTriggeredFrom.APP_RUN.value, + WorkflowRunTriggeredFrom.DEBUGGING.value, + WorkflowRunTriggeredFrom.SCHEDULE.value, + WorkflowRunTriggeredFrom.PLUGIN.value, + WorkflowRunTriggeredFrom.WEBHOOK.value, + }: + return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + if triggered_from in { + WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value, + WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value, + }: + return WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN.value + return "" + def get_node_last_execution( self, tenant_id: str, @@ -290,3 +315,119 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut result = cast(CursorResult, session.execute(stmt)) session.commit() return result.rowcount + + class RunContext(TypedDict): + run_id: str + tenant_id: str + app_id: str + workflow_id: str + triggered_from: str + + @staticmethod + def delete_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]: + """ + Delete node executions (and offloads) for the given workflow runs using indexed columns. + + Uses the composite index on (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id) + by filtering on those columns with tuple IN. 
+ """ + if not runs: + return 0, 0 + + tuple_values = [ + ( + run["tenant_id"], + run["app_id"], + run["workflow_id"], + DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from( + run["triggered_from"] + ), + run["run_id"], + ) + for run in runs + ] + + node_execution_ids = session.scalars( + select(WorkflowNodeExecutionModel.id).where( + tuple_( + WorkflowNodeExecutionModel.tenant_id, + WorkflowNodeExecutionModel.app_id, + WorkflowNodeExecutionModel.workflow_id, + WorkflowNodeExecutionModel.triggered_from, + WorkflowNodeExecutionModel.workflow_run_id, + ).in_(tuple_values) + ) + ).all() + + if not node_execution_ids: + return 0, 0 + + offloads_deleted = ( + cast( + CursorResult, + session.execute( + delete(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + ), + ).rowcount + or 0 + ) + + node_executions_deleted = ( + cast( + CursorResult, + session.execute( + delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids)) + ), + ).rowcount + or 0 + ) + + return node_executions_deleted, offloads_deleted + + @staticmethod + def count_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]: + """ + Count node executions (and offloads) for the given workflow runs using indexed columns. + """ + if not runs: + return 0, 0 + + tuple_values = [ + ( + run["tenant_id"], + run["app_id"], + run["workflow_id"], + DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from( + run["triggered_from"] + ), + run["run_id"], + ) + for run in runs + ] + tuple_filter = tuple_( + WorkflowNodeExecutionModel.tenant_id, + WorkflowNodeExecutionModel.app_id, + WorkflowNodeExecutionModel.workflow_id, + WorkflowNodeExecutionModel.triggered_from, + WorkflowNodeExecutionModel.workflow_run_id, + ).in_(tuple_values) + + node_executions_count = ( + session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(tuple_filter)) or 0 + ) + offloads_count = ( + session.scalar( + select(func.count()) + .select_from(WorkflowNodeExecutionOffload) + .join( + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id, + ) + .where(tuple_filter) + ) + or 0 + ) + + return int(node_executions_count), int(offloads_count) diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index b172c6a3ac..9d2d06e99f 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -21,7 +21,7 @@ Implementation Notes: import logging import uuid -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from decimal import Decimal from typing import Any, cast @@ -32,7 +32,7 @@ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause -from core.workflow.enums import WorkflowExecutionStatus +from core.workflow.enums import WorkflowExecutionStatus, WorkflowType from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date @@ -40,8 +40,14 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils 
import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason, WorkflowRun +from models.workflow import ( + WorkflowAppLog, + WorkflowPauseReason, + WorkflowRun, +) +from models.workflow import ( + WorkflowPause as WorkflowPauseModel, +) from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( @@ -314,6 +320,171 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id) return total_deleted + def get_runs_batch_by_time_range( + self, + start_from: datetime | None, + end_before: datetime, + last_seen: tuple[datetime, str] | None, + batch_size: int, + run_types: Sequence[WorkflowType] | None = None, + tenant_ids: Sequence[str] | None = None, + ) -> Sequence[WorkflowRun]: + """ + Fetch ended workflow runs in a time window for archival and clean batching. + + Query scope: + - created_at in [start_from, end_before) + - type in run_types (when provided) + - status is an ended state + - optional tenant_id filter and cursor (last_seen) for pagination + """ + with self._session_maker() as session: + stmt = ( + select(WorkflowRun) + .where( + WorkflowRun.created_at < end_before, + WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()), + ) + .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc()) + .limit(batch_size) + ) + if run_types is not None: + if not run_types: + return [] + stmt = stmt.where(WorkflowRun.type.in_(run_types)) + + if start_from: + stmt = stmt.where(WorkflowRun.created_at >= start_from) + + if tenant_ids: + stmt = stmt.where(WorkflowRun.tenant_id.in_(tenant_ids)) + + if last_seen: + stmt = stmt.where( + or_( + WorkflowRun.created_at > last_seen[0], + and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]), + ) + ) + + return session.scalars(stmt).all() + + def delete_runs_with_related( + self, + runs: Sequence[WorkflowRun], + delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + if not runs: + return { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + with self._session_maker() as session: + run_ids = [run.id for run in runs] + if delete_node_executions: + node_executions_deleted, offloads_deleted = delete_node_executions(session, runs) + else: + node_executions_deleted, offloads_deleted = 0, 0 + + app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))) + app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0 + + pause_ids = session.scalars( + select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) + ).all() + pause_reasons_deleted = 0 + pauses_deleted = 0 + + if pause_ids: + pause_reasons_result = session.execute( + delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) + ) + pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0 + pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids))) + pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0 + + trigger_logs_deleted = 
delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0 + + runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))) + runs_deleted = cast(CursorResult, runs_result).rowcount or 0 + + session.commit() + + return { + "runs": runs_deleted, + "node_executions": node_executions_deleted, + "offloads": offloads_deleted, + "app_logs": app_logs_deleted, + "trigger_logs": trigger_logs_deleted, + "pauses": pauses_deleted, + "pause_reasons": pause_reasons_deleted, + } + + def count_runs_with_related( + self, + runs: Sequence[WorkflowRun], + count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, + count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: + if not runs: + return { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + with self._session_maker() as session: + run_ids = [run.id for run in runs] + if count_node_executions: + node_executions_count, offloads_count = count_node_executions(session, runs) + else: + node_executions_count, offloads_count = 0, 0 + + app_logs_count = ( + session.scalar( + select(func.count()).select_from(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids)) + ) + or 0 + ) + + pause_ids = session.scalars( + select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) + ).all() + pauses_count = len(pause_ids) + pause_reasons_count = 0 + if pause_ids: + pause_reasons_count = ( + session.scalar( + select(func.count()) + .select_from(WorkflowPauseReason) + .where(WorkflowPauseReason.pause_id.in_(pause_ids)) + ) + or 0 + ) + + trigger_logs_count = count_trigger_logs(session, run_ids) if count_trigger_logs else 0 + + return { + "runs": len(runs), + "node_executions": node_executions_count, + "offloads": offloads_count, + "app_logs": int(app_logs_count), + "trigger_logs": trigger_logs_count, + "pauses": pauses_count, + "pause_reasons": int(pause_reasons_count), + } + def create_workflow_pause( self, workflow_run_id: str, diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index 0d67e286b0..ebd3745d18 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository. from collections.abc import Sequence from datetime import UTC, datetime, timedelta +from typing import cast -from sqlalchemy import and_, select +from sqlalchemy import and_, delete, func, select +from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session from models.enums import WorkflowTriggerStatus @@ -84,3 +86,37 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): ) return list(self.session.scalars(query).all()) + + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Delete trigger logs associated with the given workflow run ids. + + Args: + run_ids: Collection of workflow run identifiers. + + Returns: + Number of rows deleted. + """ + if not run_ids: + return 0 + + result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))) + return cast(CursorResult, result).rowcount or 0 + + def count_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Count trigger logs associated with the given workflow run ids. 
+ + Args: + run_ids: Collection of workflow run identifiers. + + Returns: + Number of rows matched. + """ + if not run_ids: + return 0 + + count = self.session.scalar( + select(func.count()).select_from(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids)) + ) + return int(count or 0) diff --git a/api/repositories/workflow_trigger_log_repository.py b/api/repositories/workflow_trigger_log_repository.py index 138b8779ac..b0009e398d 100644 --- a/api/repositories/workflow_trigger_log_repository.py +++ b/api/repositories/workflow_trigger_log_repository.py @@ -109,3 +109,15 @@ class WorkflowTriggerLogRepository(Protocol): A sequence of recent WorkflowTriggerLog instances """ ... + + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Delete trigger logs for workflow run IDs. + + Args: + run_ids: Workflow run IDs to delete + + Returns: + Number of rows deleted + """ + ... diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py new file mode 100644 index 0000000000..9f5bf8e150 --- /dev/null +++ b/api/schedule/clean_workflow_runs_task.py @@ -0,0 +1,43 @@ +from datetime import UTC, datetime + +import click + +import app +from configs import dify_config +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +@app.celery.task(queue="retention") +def clean_workflow_runs_task() -> None: + """ + Scheduled cleanup for workflow runs and related records (sandbox tenants only). + """ + click.echo( + click.style( + ( + "Scheduled workflow run cleanup starting: " + f"cutoff={dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS} days, " + f"batch={dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE}" + ), + fg="green", + ) + ) + + start_time = datetime.now(UTC) + + WorkflowRunCleanup( + days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS, + batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE, + start_from=None, + end_before=None, + ).run() + + end_time = datetime.now(UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Scheduled workflow run cleanup finished. 
start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) + ) diff --git a/api/services/retention/__init__.py b/api/services/retention/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/retention/workflow_run/__init__.py b/api/services/retention/workflow_run/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..2213169510 --- /dev/null +++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,301 @@ +import datetime +import logging +from collections.abc import Iterable, Sequence + +import click +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_api_workflow_node_execution_repository import ( + DifyAPISQLAlchemyWorkflowNodeExecutionRepository, +) +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository +from services.billing_service import BillingService, SubscriptionPlan + +logger = logging.getLogger(__name__) + + +class WorkflowRunCleanup: + def __init__( + self, + days: int, + batch_size: int, + start_from: datetime.datetime | None = None, + end_before: datetime.datetime | None = None, + workflow_run_repo: APIWorkflowRunRepository | None = None, + dry_run: bool = False, + ): + if (start_from is None) ^ (end_before is None): + raise ValueError("start_from and end_before must be both set or both omitted.") + + computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days) + self.window_start = start_from + self.window_end = end_before or computed_cutoff + + if self.window_start and self.window_end <= self.window_start: + raise ValueError("end_before must be greater than start_from.") + + if batch_size <= 0: + raise ValueError("batch_size must be greater than 0.") + + self.batch_size = batch_size + self._cleanup_whitelist: set[str] | None = None + self.dry_run = dry_run + self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD + self.workflow_run_repo: APIWorkflowRunRepository + if workflow_run_repo: + self.workflow_run_repo = workflow_run_repo + else: + # Lazy import to avoid circular dependencies during module import + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + def run(self) -> None: + click.echo( + click.style( + f"{'Inspecting' if self.dry_run else 'Cleaning'} workflow runs " + f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}" + f"{self.window_end.isoformat()} (batch={self.batch_size})", + fg="white", + ) + ) + if self.dry_run: + click.echo(click.style("Dry run mode enabled. 
No data will be deleted.", fg="yellow")) + + total_runs_deleted = 0 + total_runs_targeted = 0 + related_totals = self._empty_related_counts() if self.dry_run else None + batch_index = 0 + last_seen: tuple[datetime.datetime, str] | None = None + + while True: + run_rows = self.workflow_run_repo.get_runs_batch_by_time_range( + start_from=self.window_start, + end_before=self.window_end, + last_seen=last_seen, + batch_size=self.batch_size, + ) + if not run_rows: + break + + batch_index += 1 + last_seen = (run_rows[-1].created_at, run_rows[-1].id) + tenant_ids = {row.tenant_id for row in run_rows} + free_tenants = self._filter_free_tenants(tenant_ids) + free_runs = [row for row in run_rows if row.tenant_id in free_tenants] + paid_or_skipped = len(run_rows) - len(free_runs) + + if not free_runs: + click.echo( + click.style( + f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)", + fg="yellow", + ) + ) + continue + + total_runs_targeted += len(free_runs) + + if self.dry_run: + batch_counts = self.workflow_run_repo.count_runs_with_related( + free_runs, + count_node_executions=self._count_node_executions, + count_trigger_logs=self._count_trigger_logs, + ) + if related_totals is not None: + for key in related_totals: + related_totals[key] += batch_counts.get(key, 0) + sample_ids = ", ".join(run.id for run in free_runs[:5]) + click.echo( + click.style( + f"[batch #{batch_index}] would delete {len(free_runs)} runs " + f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown", + fg="yellow", + ) + ) + continue + + try: + counts = self.workflow_run_repo.delete_runs_with_related( + free_runs, + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + except Exception: + logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) + raise + + total_runs_deleted += counts["runs"] + click.echo( + click.style( + f"[batch #{batch_index}] deleted runs: {counts['runs']} " + f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); " + f"skipped {paid_or_skipped} paid/unknown", + fg="green", + ) + ) + + if self.dry_run: + if self.window_start: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"before {self.window_end.isoformat()}" + ) + if related_totals is not None: + summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}" + summary_color = "yellow" + else: + if self.window_start: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Cleanup complete. 
Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}" + ) + summary_color = "white" + + click.echo(click.style(summary_message, fg=summary_color)) + + def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]: + tenant_id_list = list(tenant_ids) + + if not dify_config.BILLING_ENABLED: + return set(tenant_id_list) + + if not tenant_id_list: + return set() + + cleanup_whitelist = self._get_cleanup_whitelist() + + try: + bulk_info = BillingService.get_plan_bulk_with_cache(tenant_id_list) + except Exception: + bulk_info = {} + logger.exception("Failed to fetch billing plans in bulk for tenants: %s", tenant_id_list) + + eligible_free_tenants: set[str] = set() + for tenant_id in tenant_id_list: + if tenant_id in cleanup_whitelist: + continue + + info = bulk_info.get(tenant_id) + if info is None: + logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id) + continue + + if info.get("plan") != CloudPlan.SANDBOX: + continue + + if self._is_within_grace_period(tenant_id, info): + continue + + eligible_free_tenants.add(tenant_id) + + return eligible_free_tenants + + def _expiration_datetime(self, tenant_id: str, expiration_value: int) -> datetime.datetime | None: + if expiration_value < 0: + return None + + try: + return datetime.datetime.fromtimestamp(expiration_value, datetime.UTC) + except (OverflowError, OSError, ValueError): + logger.exception("Failed to parse expiration timestamp for tenant %s", tenant_id) + return None + + def _is_within_grace_period(self, tenant_id: str, info: SubscriptionPlan) -> bool: + if self.free_plan_grace_period_days <= 0: + return False + + expiration_value = info.get("expiration_date", -1) + expiration_at = self._expiration_datetime(tenant_id, expiration_value) + if expiration_at is None: + return False + + grace_deadline = expiration_at + datetime.timedelta(days=self.free_plan_grace_period_days) + return datetime.datetime.now(datetime.UTC) < grace_deadline + + def _get_cleanup_whitelist(self) -> set[str]: + if self._cleanup_whitelist is not None: + return self._cleanup_whitelist + + if not dify_config.BILLING_ENABLED: + self._cleanup_whitelist = set() + return self._cleanup_whitelist + + try: + whitelist_ids = BillingService.get_expired_subscription_cleanup_whitelist() + except Exception: + logger.exception("Failed to fetch cleanup whitelist from billing service") + whitelist_ids = [] + + self._cleanup_whitelist = set(whitelist_ids) + return self._cleanup_whitelist + + def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) + + def _count_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.count_by_run_ids(run_ids) + + @staticmethod + def _build_run_contexts( + runs: Sequence[WorkflowRun], + ) -> list[DifyAPISQLAlchemyWorkflowNodeExecutionRepository.RunContext]: + return [ + { + "run_id": run.id, + "tenant_id": run.tenant_id, + "app_id": run.app_id, + "workflow_id": run.workflow_id, + "triggered_from": run.triggered_from, + } + for run in runs + ] + + @staticmethod + def _empty_related_counts() -> dict[str, int]: + return { + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + @staticmethod + def _format_related_counts(counts: dict[str, int]) -> str: + return ( + f"node_executions 
{counts['node_executions']}, " + f"offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, " + f"trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, " + f"pause_reasons {counts['pause_reasons']}" + ) + + def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_contexts = self._build_run_contexts(runs) + return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.count_by_runs(session, run_contexts) + + def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_contexts = self._build_run_contexts(runs) + return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs(session, run_contexts) diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 0c34676252..d443c4c9a5 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -4,6 +4,7 @@ from datetime import UTC, datetime from unittest.mock import Mock, patch import pytest +from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session, sessionmaker from core.workflow.enums import WorkflowExecutionStatus @@ -104,6 +105,42 @@ class TestDifyAPISQLAlchemyWorkflowRunRepository: return pause +class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_get_runs_batch_by_time_range_filters_terminal_statuses( + self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock + ): + scalar_result = Mock() + scalar_result.all.return_value = [] + mock_session.scalars.return_value = scalar_result + + repository.get_runs_batch_by_time_range( + start_from=None, + end_before=datetime(2024, 1, 1), + last_seen=None, + batch_size=50, + ) + + stmt = mock_session.scalars.call_args[0][0] + compiled_sql = str( + stmt.compile( + dialect=postgresql.dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + + assert "workflow_runs.status" in compiled_sql + for status in ( + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.STOPPED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + ): + assert f"'{status.value}'" in compiled_sql + + assert "'running'" not in compiled_sql + assert "'paused'" not in compiled_sql + + class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test create_workflow_pause method.""" @@ -181,6 +218,61 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): ) +class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): + node_ids_result = Mock() + node_ids_result.all.return_value = [] + pause_ids_result = Mock() + pause_ids_result.all.return_value = [] + mock_session.scalars.side_effect = [node_ids_result, pause_ids_result] + + # app_logs delete, runs delete + mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)] + + fake_trigger_repo = Mock() + fake_trigger_repo.delete_by_run_ids.return_value = 3 + + run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf") + counts = repository.delete_runs_with_related( + [run], + delete_node_executions=lambda session, runs: (2, 1), + delete_trigger_logs=lambda session, run_ids: 
fake_trigger_repo.delete_by_run_ids(run_ids), + ) + + fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"]) + assert counts["node_executions"] == 2 + assert counts["offloads"] == 1 + assert counts["trigger_logs"] == 3 + assert counts["runs"] == 1 + + +class TestCountRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): + pause_ids_result = Mock() + pause_ids_result.all.return_value = ["pause-1", "pause-2"] + mock_session.scalars.return_value = pause_ids_result + mock_session.scalar.side_effect = [5, 2] + + fake_trigger_repo = Mock() + fake_trigger_repo.count_by_run_ids.return_value = 3 + + run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf") + counts = repository.count_runs_with_related( + [run], + count_node_executions=lambda session, runs: (2, 1), + count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids), + ) + + fake_trigger_repo.count_by_run_ids.assert_called_once_with(["run-1"]) + assert counts["node_executions"] == 2 + assert counts["offloads"] == 1 + assert counts["trigger_logs"] == 3 + assert counts["app_logs"] == 5 + assert counts["pauses"] == 2 + assert counts["pause_reasons"] == 2 + assert counts["runs"] == 1 + + class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test resume_workflow_pause method.""" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py new file mode 100644 index 0000000000..d409618211 --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py @@ -0,0 +1,31 @@ +from unittest.mock import Mock + +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import Session + +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository + + +def test_delete_by_run_ids_executes_delete(): + session = Mock(spec=Session) + session.execute.return_value = Mock(rowcount=2) + repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + deleted = repo.delete_by_run_ids(["run-1", "run-2"]) + + stmt = session.execute.call_args[0][0] + compiled_sql = str(stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})) + assert "workflow_trigger_logs" in compiled_sql + assert "'run-1'" in compiled_sql + assert "'run-2'" in compiled_sql + assert deleted == 2 + + +def test_delete_by_run_ids_empty_short_circuits(): + session = Mock(spec=Session) + repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + deleted = repo.delete_by_run_ids([]) + + session.execute.assert_not_called() + assert deleted == 0 diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..8c80e2b4ad --- /dev/null +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,327 @@ +import datetime +from typing import Any + +import pytest + +from services.billing_service import SubscriptionPlan +from services.retention.workflow_run import clear_free_plan_expired_workflow_run_logs as cleanup_module +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +class FakeRun: + def __init__( + self, + 
run_id: str, + tenant_id: str, + created_at: datetime.datetime, + app_id: str = "app-1", + workflow_id: str = "wf-1", + triggered_from: str = "workflow-run", + ) -> None: + self.id = run_id + self.tenant_id = tenant_id + self.app_id = app_id + self.workflow_id = workflow_id + self.triggered_from = triggered_from + self.created_at = created_at + + +class FakeRepo: + def __init__( + self, + batches: list[list[FakeRun]], + delete_result: dict[str, int] | None = None, + count_result: dict[str, int] | None = None, + ) -> None: + self.batches = batches + self.call_idx = 0 + self.deleted: list[list[str]] = [] + self.counted: list[list[str]] = [] + self.delete_result = delete_result or { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + self.count_result = count_result or { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + def get_runs_batch_by_time_range( + self, + start_from: datetime.datetime | None, + end_before: datetime.datetime, + last_seen: tuple[datetime.datetime, str] | None, + batch_size: int, + ) -> list[FakeRun]: + if self.call_idx >= len(self.batches): + return [] + batch = self.batches[self.call_idx] + self.call_idx += 1 + return batch + + def delete_runs_with_related( + self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None + ) -> dict[str, int]: + self.deleted.append([run.id for run in runs]) + result = self.delete_result.copy() + result["runs"] = len(runs) + return result + + def count_runs_with_related( + self, runs: list[FakeRun], count_node_executions=None, count_trigger_logs=None + ) -> dict[str, int]: + self.counted.append([run.id for run in runs]) + result = self.count_result.copy() + result["runs"] = len(runs) + return result + + +def plan_info(plan: str, expiration: int) -> SubscriptionPlan: + return SubscriptionPlan(plan=plan, expiration_date=expiration) + + +def create_cleanup( + monkeypatch: pytest.MonkeyPatch, + repo: FakeRepo, + *, + grace_period_days: int = 0, + whitelist: set[str] | None = None, + **kwargs: Any, +) -> WorkflowRunCleanup: + monkeypatch.setattr( + cleanup_module.dify_config, + "SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD", + grace_period_days, + ) + monkeypatch.setattr( + cleanup_module.WorkflowRunCleanup, + "_get_cleanup_whitelist", + lambda self: whitelist or set(), + ) + return WorkflowRunCleanup(workflow_run_repo=repo, **kwargs) + + +def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + def fail_bulk(_: list[str]) -> dict[str, SubscriptionPlan]: + raise RuntimeError("should not call") + + monkeypatch.setattr(cleanup_module.BillingService, "get_plan_bulk_with_cache", staticmethod(fail_bulk)) + + tenants = {"t1", "t2"} + free = cleanup._filter_free_tenants(tenants) + + assert free == tenants + + +def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for 
tenant_id in tenant_ids + } + ), + ) + + free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"}) + + assert free == {"t_free", "t_missing"} + + +def test_filter_free_tenants_respects_grace_period(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, grace_period_days=45) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + now = datetime.datetime.now(datetime.UTC) + within_grace_ts = int((now - datetime.timedelta(days=10)).timestamp()) + outside_grace_ts = int((now - datetime.timedelta(days=90)).timestamp()) + + def fake_bulk(_: list[str]) -> dict[str, SubscriptionPlan]: + return { + "recently_downgraded": plan_info("sandbox", within_grace_ts), + "long_sandbox": plan_info("sandbox", outside_grace_ts), + } + + monkeypatch.setattr(cleanup_module.BillingService, "get_plan_bulk_with_cache", staticmethod(fake_bulk)) + + free = cleanup._filter_free_tenants({"recently_downgraded", "long_sandbox"}) + + assert free == {"long_sandbox"} + + +def test_filter_free_tenants_skips_cleanup_whitelist(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup( + monkeypatch, + repo=FakeRepo([]), + days=30, + batch_size=10, + whitelist={"tenant_whitelist"}, + ) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for tenant_id in tenant_ids + } + ), + ) + + tenants = {"tenant_whitelist", "tenant_regular"} + free = cleanup._filter_free_tenants(tenants) + + assert free == {"tenant_regular"} + + +def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod(lambda tenant_ids: (_ for _ in ()).throw(RuntimeError("boom"))), + ) + + free = cleanup._filter_free_tenants({"t1", "t2"}) + + assert free == set() + + +def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[ + [ + FakeRun("run-free", "t_free", cutoff), + FakeRun("run-paid", "t_paid", cutoff), + ] + ] + ) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod( + lambda tenant_ids: { + tenant_id: (plan_info("team", -1) if tenant_id == "t_paid" else plan_info("sandbox", -1)) + for tenant_id in tenant_ids + } + ), + ) + + cleanup.run() + + assert repo.deleted == [["run-free"]] + + +def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo(batches=[[FakeRun("run-paid", "t_paid", cutoff)]]) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_plan_bulk_with_cache", + staticmethod(lambda tenant_ids: {tenant_id: plan_info("team", 1893456000) for tenant_id in tenant_ids}), + ) + + cleanup.run() + + assert 
repo.deleted == [] + + +def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) + + cleanup.run() + + +def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[[FakeRun("run-free", "t_free", cutoff)]], + count_result={ + "runs": 0, + "node_executions": 2, + "offloads": 1, + "app_logs": 3, + "trigger_logs": 4, + "pauses": 5, + "pause_reasons": 6, + }, + ) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10, dry_run=True) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + cleanup.run() + + assert repo.deleted == [] + assert repo.counted == [["run-free"]] + captured = capsys.readouterr().out + assert "Dry run mode enabled" in captured + assert "would delete 1 runs" in captured + assert "related records" in captured + assert "node_executions 2" in captured + assert "offloads 1" in captured + assert "app_logs 3" in captured + assert "trigger_logs 4" in captured + assert "pauses 5" in captured + assert "pause_reasons 6" in captured + + +def test_between_sets_window_bounds(monkeypatch: pytest.MonkeyPatch) -> None: + start_from = datetime.datetime(2024, 5, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 6, 1, 0, 0, 0) + cleanup = create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=start_from, end_before=end_before + ) + + assert cleanup.window_start == start_from + assert cleanup.window_end == end_before + + +def test_between_requires_both_boundaries(monkeypatch: pytest.MonkeyPatch) -> None: + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=datetime.datetime.now(), end_before=None + ) + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=None, end_before=datetime.datetime.now() + ) + + +def test_between_requires_end_after_start(monkeypatch: pytest.MonkeyPatch) -> None: + start_from = datetime.datetime(2024, 6, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 5, 1, 0, 0, 0) + with pytest.raises(ValueError): + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_from=start_from, end_before=end_before + ) diff --git a/docker/.env.example b/docker/.env.example index 09ee1060e2..e7cb8711ce 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1478,6 +1478,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false ENABLE_CREATE_TIDB_SERVERLESS_TASK=false ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false ENABLE_CLEAN_MESSAGES=false +ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 712de84c62..041f60aaa2 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -662,6 +662,7 @@ x-shared-env: &shared-api-worker-env ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false} ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: ${ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:-false} ENABLE_CLEAN_MESSAGES: ${ENABLE_CLEAN_MESSAGES:-false} + ENABLE_WORKFLOW_RUN_CLEANUP_TASK: ${ENABLE_WORKFLOW_RUN_CLEANUP_TASK:-false} ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false} ENABLE_DATASETS_QUEUE_MONITOR: 
${ENABLE_DATASETS_QUEUE_MONITOR:-false} ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true}
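Usage note: the scheduled path only runs when ENABLE_WORKFLOW_RUN_CLEANUP_TASK is enabled (a daily celery beat entry dispatched to the retention queue), while the clean-workflow-runs CLI command covers ad-hoc and backfill cleanups. The sketch below mirrors what the command wraps, assuming an active Flask application context (the CLI provides one); the window dates are purely illustrative.

    import datetime

    from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup

    window = dict(
        start_from=datetime.datetime(2024, 5, 1),
        end_before=datetime.datetime(2024, 6, 1),
    )

    # Preview: batches are fetched and related records counted, nothing is deleted.
    WorkflowRunCleanup(days=30, batch_size=200, dry_run=True, **window).run()

    # Real deletion over the same window. With an explicit window, `days` no longer
    # determines the cutoff, since window_end is taken from end_before.
    WorkflowRunCleanup(days=30, batch_size=200, **window).run()

When billing is enabled, only sandbox-plan tenants outside the cleanup whitelist and the configured grace period are targeted; with billing disabled, every tenant whose runs fall in the window is treated as eligible.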
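The trigger-log repository follows the same count-before-delete split as the run repository. A small sketch of calling it directly inside one session; the run id is illustrative, and the commit is left to the caller, just as delete_runs_with_related commits after invoking the injected callback.

    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository

    with Session(db.engine) as session:
        repo = SQLAlchemyWorkflowTriggerLogRepository(session)
        run_ids = ["00000000-0000-0000-0000-000000000000"]  # illustrative run id
        matched = repo.count_by_run_ids(run_ids)   # rows that would be removed
        deleted = repo.delete_by_run_ids(run_ids)  # issues the DELETE, does not commit
        session.commit()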