diff --git a/api/commands.py b/api/commands.py
index a8d89ac200..cdf3a09bb7 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,4 +1,5 @@
 import base64
+import datetime
 import json
 import logging
 import secrets
@@ -41,6 +42,7 @@ from models.provider_ids import DatasourceProviderID, ToolProviderID
 from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
 from models.tools import ToolOAuthSystemClient
 from services.account_service import AccountService, RegisterService, TenantService
+from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
 from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
 from services.plugin.data_migration import PluginDataMigration
 from services.plugin.plugin_migration import PluginMigration
@@ -852,6 +854,45 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
     click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
 
 
+@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
+@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
+@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting workflow runs.")
+@click.option(
+    "--start-after",
+    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+    default=None,
+    help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
+)
+@click.option(
+    "--end-before",
+    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
+    default=None,
+    help="Optional upper bound (exclusive) for created_at; must be paired with --start-after.",
+)
+def clean_workflow_runs(
+    days: int,
+    batch_size: int,
+    start_after: datetime.datetime | None,
+    end_before: datetime.datetime | None,
+):
+    """
+    Clean workflow runs and related workflow data for free tenants.
+    """
+    if (start_after is None) ^ (end_before is None):
+        raise click.UsageError("--start-after and --end-before must be provided together.")
+
+    click.echo(click.style("Starting workflow run cleanup.", fg="white"))
+
+    WorkflowRunCleanup(
+        days=days,
+        batch_size=batch_size,
+        start_after=start_after,
+        end_before=end_before,
+    ).run()
+
+    click.echo(click.style("Workflow run cleanup completed.", fg="green"))
+
+
 @click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
 @click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
 def clear_orphaned_file_records(force: bool):
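A quick, database-free way to exercise the option pairing check added above is click's test runner; a minimal sketch, assuming the command stays importable from the commands module as registered in ext_commands.py below:

    # Sketch only: --start-after without --end-before should be rejected with a UsageError.
    from click.testing import CliRunner

    from commands import clean_workflow_runs  # assumed import path

    result = CliRunner().invoke(clean_workflow_runs, ["--days", "30", "--start-after", "2024-05-01"])
    assert result.exit_code != 0  # rejected before any cleanup work starts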
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index a5916241df..e237470709 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -1116,6 +1116,11 @@ class CeleryScheduleTasksConfig(BaseSettings):
         default=60 * 60,
     )
 
+    ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
+        description="Enable scheduled workflow run cleanup task",
+        default=False,
+    )
+
 
 class PositionConfig(BaseSettings):
     POSITION_PROVIDER_PINS: str = Field(
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 5cf4984709..c19763372d 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -160,6 +160,13 @@ def init_app(app: DifyApp) -> Celery:
             "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise",
             "schedule": crontab(minute="0", hour="2"),
         }
+    if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK:
+        # for saas only
+        imports.append("schedule.clean_workflow_runs_task")
+        beat_schedule["clean_workflow_runs_task"] = {
+            "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
+            "schedule": crontab(minute="0", hour="0"),
+        }
     if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
         imports.append("schedule.workflow_schedule_task")
         beat_schedule["workflow_schedule_task"] = {
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index 71a63168a5..6f6322827c 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -4,6 +4,7 @@ from dify_app import DifyApp
 def init_app(app: DifyApp):
     from commands import (
         add_qdrant_index,
+        clean_workflow_runs,
         cleanup_orphaned_draft_variables,
         clear_free_plan_tenant_expired_logs,
         clear_orphaned_file_records,
@@ -54,6 +55,7 @@ def init_app(app: DifyApp):
         setup_datasource_oauth_client,
         transform_datasource_credentials,
         install_rag_pipeline_plugins,
+        clean_workflow_runs,
     ]
     for cmd in cmds_to_register:
         app.cli.add_command(cmd)
diff --git a/api/migrations/versions/2025_12_10_1504-8a7f2ad7c23e_add_workflow_runs_created_at_idx.py b/api/migrations/versions/2025_12_10_1504-8a7f2ad7c23e_add_workflow_runs_created_at_idx.py
new file mode 100644
index 0000000000..7968429ca8
--- /dev/null
+++ b/api/migrations/versions/2025_12_10_1504-8a7f2ad7c23e_add_workflow_runs_created_at_idx.py
@@ -0,0 +1,29 @@
+"""Add index on workflow_runs.created_at
+
+Revision ID: 8a7f2ad7c23e
+Revises: d57accd375ae
+Create Date: 2025-12-10 15:04:00.000000
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "8a7f2ad7c23e"
+down_revision = "d57accd375ae"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    with op.batch_alter_table("workflow_runs", schema=None) as batch_op:
+        batch_op.create_index(
+            batch_op.f("workflow_runs_created_at_idx"),
+            ["created_at"],
+            unique=False,
+        )
+
+
+def downgrade():
+    with op.batch_alter_table("workflow_runs", schema=None) as batch_op:
+        batch_op.drop_index(batch_op.f("workflow_runs_created_at_idx"))
diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py
new file mode 100644
index 0000000000..b59dc7f823
--- /dev/null
+++ b/api/schedule/clean_workflow_runs_task.py
@@ -0,0 +1,30 @@
+import click
+
+import app
+from configs import dify_config
+from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
+
+CLEANUP_QUEUE = "retention"
+
+
+@app.celery.task(queue=CLEANUP_QUEUE)
+def clean_workflow_runs_task() -> None:
+    """
+    Scheduled cleanup for workflow runs and related records (sandbox tenants only).
+    """
+    click.echo(
+        click.style(
+            f"Scheduled workflow run cleanup starting: cutoff={dify_config.WORKFLOW_LOG_RETENTION_DAYS} days, "
+            f"batch={dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE}",
+            fg="green",
+        )
+    )
+
+    WorkflowRunCleanup(
+        days=dify_config.WORKFLOW_LOG_RETENTION_DAYS,
+        batch_size=dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE,
+        start_after=None,
+        end_before=None,
+    ).run()
+
+    click.echo(click.style("Scheduled workflow run cleanup finished.", fg="green"))
diff --git a/api/services/billing_service.py b/api/services/billing_service.py
index 54e1c9d285..4fc61c6c0f 100644
--- a/api/services/billing_service.py
+++ b/api/services/billing_service.py
@@ -1,3 +1,4 @@
+import logging
 import os
 from typing import Literal
 
@@ -11,6 +12,8 @@ from extensions.ext_redis import redis_client
 from libs.helper import RateLimiter
 from models import Account, TenantAccountJoin, TenantAccountRole
 
+logger = logging.getLogger(__name__)
+
 
 class BillingService:
     base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
@@ -25,6 +28,25 @@ class BillingService:
         billing_info = cls._send_request("GET", "/subscription/info", params=params)
         return billing_info
 
+    @classmethod
+    def get_info_bulk(cls, tenant_ids: list[str]) -> dict[str, dict]:
+        """
+        Temporary bulk billing info fetch. Will be replaced by a real batch API.
+
+        Args:
+            tenant_ids: list of tenant ids
+
+        Returns:
+            Mapping of tenant_id -> billing info dict
+        """
+        result: dict[str, dict] = {}
+        for tenant_id in tenant_ids:
+            try:
+                result[tenant_id] = cls.get_info(tenant_id)
+            except Exception:
+                logger.exception("Failed to fetch billing info for tenant %s in bulk mode", tenant_id)
+        return result
+
     @classmethod
     def get_tenant_feature_plan_usage_info(cls, tenant_id: str):
         params = {"tenant_id": tenant_id}
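For orientation, the cleanup service introduced below only reads the plan out of each per-tenant entry, so the bulk result is expected to look roughly like the following sketch (tenant ids and plan strings are illustrative, not taken from this patch):

    # WorkflowRunCleanup._filter_free_tenants reads:
    #     info.get("subscription", {}).get("plan") and compares it against CloudPlan.SANDBOX
    bulk_info = {
        "tenant-a": {"subscription": {"plan": "sandbox"}},  # free plan: its runs are eligible for deletion
        "tenant-b": {"subscription": {"plan": "team"}},     # paid plan: its runs are skipped
    }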
diff --git a/api/services/clear_free_plan_expired_workflow_run_logs.py b/api/services/clear_free_plan_expired_workflow_run_logs.py
new file mode 100644
index 0000000000..98543c2928
--- /dev/null
+++ b/api/services/clear_free_plan_expired_workflow_run_logs.py
@@ -0,0 +1,235 @@
+import datetime
+import logging
+from collections.abc import Iterable, Sequence
+from dataclasses import dataclass
+
+import click
+import sqlalchemy as sa
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from enums.cloud_plan import CloudPlan
+from extensions.ext_database import db
+from models import WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
+from models.trigger import WorkflowTriggerLog
+from models.workflow import WorkflowNodeExecutionOffload, WorkflowPause, WorkflowPauseReason
+from services.billing_service import BillingService
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class WorkflowRunRow:
+    id: str
+    tenant_id: str
+    created_at: datetime.datetime
+
+
+class WorkflowRunCleanup:
+    def __init__(
+        self,
+        days: int,
+        batch_size: int,
+        start_after: datetime.datetime | None = None,
+        end_before: datetime.datetime | None = None,
+    ):
+        if (start_after is None) ^ (end_before is None):
+            raise ValueError("start_after and end_before must be both set or both omitted.")
+
+        computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days)
+        self.window_start = start_after
+        self.window_end = end_before or computed_cutoff
+
+        if self.window_start and self.window_end <= self.window_start:
+            raise ValueError("end_before must be greater than start_after.")
+
+        self.batch_size = batch_size
+        self.billing_cache: dict[str, CloudPlan | None] = {}
+
+    def run(self) -> None:
+        click.echo(
+            click.style(
+                f"Cleaning workflow runs "
+                f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}"
+                f"{self.window_end.isoformat()} (batch={self.batch_size})",
+                fg="white",
+            )
+        )
+
+        total_runs_deleted = 0
+        batch_index = 0
+        last_seen: tuple[datetime.datetime, str] | None = None
+
+        while True:
+            with Session(db.engine) as session:
+                run_rows = self._load_batch(session, last_seen)
+                if not run_rows:
+                    break
+
+                batch_index += 1
+                last_seen = (run_rows[-1].created_at, run_rows[-1].id)
+                tenant_ids = {row.tenant_id for row in run_rows}
+                free_tenants = self._filter_free_tenants(tenant_ids)
+                free_run_ids = [row.id for row in run_rows if row.tenant_id in free_tenants]
+                paid_or_skipped = len(run_rows) - len(free_run_ids)
+
+                if not free_run_ids:
+                    click.echo(
+                        click.style(
+                            f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid)",
+                            fg="yellow",
+                        )
+                    )
+                    continue
+
+                try:
+                    counts = self._delete_runs(session, free_run_ids)
+                    session.commit()
+                except Exception:
+                    session.rollback()
+                    logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
+                    raise
+
+                total_runs_deleted += counts["runs"]
+                click.echo(
+                    click.style(
+                        f"[batch #{batch_index}] deleted runs: {counts['runs']} "
+                        f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
+                        f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
+                        f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
+                        f"skipped {paid_or_skipped} paid/unknown",
+                        fg="green",
+                    )
+                )
+
+        if self.window_start:
+            summary_message = (
+                f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
+                f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
+            )
+        else:
+            summary_message = (
+                f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
+            )
+
+        click.echo(click.style(summary_message, fg="white"))
+
+    def _load_batch(
+        self, session: Session, last_seen: tuple[datetime.datetime, str] | None
+    ) -> list[WorkflowRunRow]:
+        stmt = (
+            select(WorkflowRun.id, WorkflowRun.tenant_id, WorkflowRun.created_at)
+            .where(WorkflowRun.created_at < self.window_end)
+            .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
+            .limit(self.batch_size)
+        )
+
+        if self.window_start:
+            stmt = stmt.where(WorkflowRun.created_at >= self.window_start)
+
+        if last_seen:
+            stmt = stmt.where(
+                sa.or_(
+                    WorkflowRun.created_at > last_seen[0],
+                    sa.and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
+                )
+            )
+
+        rows = session.execute(stmt).all()
+        return [WorkflowRunRow(id=row.id, tenant_id=row.tenant_id, created_at=row.created_at) for row in rows]
+
+    def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
+        if not dify_config.BILLING_ENABLED:
+            return set(tenant_ids)
+
+        tenant_id_list = list(tenant_ids)
+        uncached_tenants = [tenant_id for tenant_id in tenant_id_list if tenant_id not in self.billing_cache]
+
+        if uncached_tenants:
+            try:
+                bulk_info = BillingService.get_info_bulk(uncached_tenants)
+            except Exception:
+                bulk_info = {}
+                logger.exception("Failed to fetch billing plans in bulk for tenants: %s", uncached_tenants)
+
+            for tenant_id in uncached_tenants:
+                plan: CloudPlan | None = None
+                info = bulk_info.get(tenant_id)
+                if info:
+                    try:
+                        raw_plan = info.get("subscription", {}).get("plan")
+                        plan = CloudPlan(raw_plan)
+                    except Exception:
+                        logger.exception("Failed to parse billing plan for tenant %s", tenant_id)
+                else:
+                    logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id)
+
+                self.billing_cache[tenant_id] = plan
+
+        return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX}
+
+    def _delete_runs(self, session: Session, workflow_run_ids: Sequence[str]) -> dict[str, int]:
+        node_execution_ids = session.scalars(
+            select(WorkflowNodeExecutionModel.id).where(WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids))
+        ).all()
+
+        offloads_deleted = 0
+        if node_execution_ids:
+            offloads_deleted = (
+                session.query(WorkflowNodeExecutionOffload)
+                .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
+                .delete(synchronize_session=False)
+            )
+
+        node_executions_deleted = 0
+        if node_execution_ids:
+            node_executions_deleted = (
+                session.query(WorkflowNodeExecutionModel)
+                .where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
+                .delete(synchronize_session=False)
+            )
+
+        app_logs_deleted = (
+            session.query(WorkflowAppLog)
+            .where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids))
+            .delete(synchronize_session=False)
+        )
+
+        pause_ids = session.scalars(
+            select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(workflow_run_ids))
+        ).all()
+        pause_reasons_deleted = 0
+        pauses_deleted = 0
+
+        if pause_ids:
+            pause_reasons_deleted = (
+                session.query(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)).delete(
+                    synchronize_session=False
+                )
+            )
+            pauses_deleted = (
+                session.query(WorkflowPause)
+                .where(WorkflowPause.id.in_(pause_ids))
+                .delete(synchronize_session=False)
+            )
+
+        trigger_logs_deleted = (
+            session.query(WorkflowTriggerLog)
+            .where(WorkflowTriggerLog.workflow_run_id.in_(workflow_run_ids))
+            .delete(synchronize_session=False)
+        )
+
+        runs_deleted = (
+            session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
+        )
+
+        return {
+            "runs": runs_deleted,
+            "node_executions": node_executions_deleted,
+            "offloads": offloads_deleted,
+            "app_logs": app_logs_deleted,
+            "trigger_logs": trigger_logs_deleted,
+            "pauses": pauses_deleted,
+            "pause_reasons": pause_reasons_deleted,
+        }
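The batch loader above resumes with an explicit (created_at, id) keyset predicate instead of OFFSET, so each pass stays a bounded scan over the new workflow_runs_created_at_idx even while earlier rows are deleted. For reference, a sketch of an equivalent formulation that plugs into _load_batch, using SQLAlchemy's row-value comparison (an alternative, not what the patch uses; PostgreSQL supports this natively):

    import sqlalchemy as sa

    # (created_at, id) > (:last_created_at, :last_id) -- same ordering as the sa.or_/sa.and_ pair above
    last_created_at, last_id = last_seen
    stmt = stmt.where(
        sa.tuple_(WorkflowRun.created_at, WorkflowRun.id) > sa.tuple_(last_created_at, last_id)
    )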
diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py
new file mode 100644
index 0000000000..fd756bcddd
--- /dev/null
+++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py
@@ -0,0 +1,231 @@
+import datetime
+from typing import Any
+
+import pytest
+
+from services import clear_free_plan_expired_workflow_run_logs as cleanup_module
+from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup, WorkflowRunRow
+
+
+class DummySession:
+    def __init__(self) -> None:
+        self.committed = False
+
+    def __enter__(self) -> "DummySession":
+        return self
+
+    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
+        return None
+
+    def commit(self) -> None:
+        self.committed = True
+
+
+def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
+    cleanup = WorkflowRunCleanup(days=30, batch_size=10)
+
+    monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
+
+    def fail_bulk(_: list[str]) -> dict[str, dict[str, Any]]:
+        raise RuntimeError("should not call")
+
+    monkeypatch.setattr(cleanup_module.BillingService, "get_info_bulk", staticmethod(fail_bulk))
+
+    tenants = {"t1", "t2"}
+    free = cleanup._filter_free_tenants(tenants)
+
+    assert free == tenants
+
+
+def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None:
+    cleanup = WorkflowRunCleanup(days=30, batch_size=10)
+
+    monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
+    # seed cache to avoid relying on billing service implementation
+    cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX
+    cleanup.billing_cache["t_paid"] = cleanup_module.CloudPlan.TEAM
+    monkeypatch.setattr(
+        cleanup_module.BillingService,
+        "get_info_bulk",
+        staticmethod(lambda tenant_ids: {tenant_id: {} for tenant_id in tenant_ids}),
+    )
+
+    free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"})
+
+    assert free == {"t_free"}
+
+
+def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None:
+    cleanup = WorkflowRunCleanup(days=30, batch_size=10)
+
+    monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
+    monkeypatch.setattr(
+        cleanup_module.BillingService,
+        "get_info_bulk",
+        staticmethod(lambda tenant_ids: (_ for _ in ()).throw(RuntimeError("boom"))),
+    )
+
+    free = cleanup._filter_free_tenants({"t1", "t2"})
+
+    assert free == set()
+
+
+def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
+    cutoff = datetime.datetime.now()
+    cleanup = WorkflowRunCleanup(days=30, batch_size=10)
+
+    monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
+    cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX
+    cleanup.billing_cache["t_paid"] = cleanup_module.CloudPlan.TEAM
+    monkeypatch.setattr(
+        cleanup_module.BillingService,
+        "get_info_bulk",
+        staticmethod(lambda tenant_ids: {tenant_id: {} for tenant_id in tenant_ids}),
+    )
+
+    batches_returned = 0
+
+    def fake_load_batch(
+        session: DummySession, last_seen: tuple[datetime.datetime, str] | None
+    ) -> list[WorkflowRunRow]:
+        nonlocal batches_returned
+        if batches_returned > 0:
+            return []
+        batches_returned += 1
+        return [
+            WorkflowRunRow(id="run-free", tenant_id="t_free", created_at=cutoff),
+            WorkflowRunRow(id="run-paid", tenant_id="t_paid", created_at=cutoff),
+        ]
+
+    deleted_ids: list[list[str]] = []
+
+    def fake_delete_runs(session: DummySession, workflow_run_ids: list[str]) -> dict[str, int]:
+        deleted_ids.append(list(workflow_run_ids))
+        return {
+            "runs": len(workflow_run_ids),
+            "node_executions": 0,
+            "offloads": 0,
+            "app_logs": 0,
+            "trigger_logs": 0,
+            "pauses": 0,
+            "pause_reasons": 0,
+        }
+
+    created_sessions: list[DummySession] = []
+
+    def fake_session_factory(engine: object | None = None) -> DummySession:
+        session = DummySession()
+        created_sessions.append(session)
+        return session
+
+    monkeypatch.setattr(cleanup, "_load_batch", fake_load_batch)
+    monkeypatch.setattr(cleanup, "_delete_runs", fake_delete_runs)
+    monkeypatch.setattr(cleanup_module, "Session", fake_session_factory)
+
+    class DummyDB:
+        engine: object | None = None
+
+    monkeypatch.setattr(cleanup_module, "db", DummyDB())
+
+    cleanup.run()
+
+    assert deleted_ids == [["run-free"]]
+    assert created_sessions
+    assert created_sessions[0].committed is True
+
+
+def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
+    cutoff = datetime.datetime.now()
+    cleanup = WorkflowRunCleanup(days=30, batch_size=10)
+
+    monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
+    monkeypatch.setattr(
+        cleanup_module.BillingService,
+        "get_info_bulk",
+        staticmethod(
+            lambda tenant_ids: {tenant_id: {"subscription": {"plan": "TEAM"}} for tenant_id in tenant_ids}
+        ),
+    )
+
+    batches_returned = 0
+
+    def fake_load_batch(
+        session: DummySession, last_seen: tuple[datetime.datetime, str] | None
+    ) -> list[WorkflowRunRow]:
+        nonlocal batches_returned
+        if batches_returned > 0:
+            return []
+        batches_returned += 1
+        return [WorkflowRunRow(id="run-paid", tenant_id="t_paid", created_at=cutoff)]
+
+    delete_called = False
+
+    def fake_delete_runs(session: DummySession, workflow_run_ids: list[str]) -> dict[str, int]:
+        nonlocal delete_called
+        delete_called = True
+        return {
+            "runs": 0,
+            "node_executions": 0,
+            "offloads": 0,
+            "app_logs": 0,
+            "trigger_logs": 0,
+            "pauses": 0,
+            "pause_reasons": 0,
+        }
+
+    def fake_session_factory(engine: object | None = None) -> DummySession:  # pragma: no cover - simple factory
+        return DummySession()
+
+    monkeypatch.setattr(cleanup, "_load_batch", fake_load_batch)
+    monkeypatch.setattr(cleanup, "_delete_runs", fake_delete_runs)
+    monkeypatch.setattr(cleanup_module, "Session", fake_session_factory)
+    monkeypatch.setattr(cleanup_module, "db", type("DummyDB", (), {"engine": None}))
+
+    cleanup.run()
+
+    assert delete_called is False
+
+
+def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None:
+    cleanup = WorkflowRunCleanup(days=30, batch_size=10)
+
+    def fake_load_batch(
+        session: DummySession, last_seen: tuple[datetime.datetime, str] | None
+    ) -> list[WorkflowRunRow]:
+        return []
+
+    def fake_delete_runs(session: DummySession, workflow_run_ids: list[str]) -> dict[str, int]:
+        raise AssertionError("should not delete")
+
+    def fake_session_factory(engine: object | None = None) -> DummySession:  # pragma: no cover - simple factory
+        return DummySession()
+
+    monkeypatch.setattr(cleanup, "_load_batch", fake_load_batch)
+    monkeypatch.setattr(cleanup, "_delete_runs", fake_delete_runs)
+    monkeypatch.setattr(cleanup_module, "Session", fake_session_factory)
+    monkeypatch.setattr(cleanup_module, "db", type("DummyDB", (), {"engine": None}))
+
+    cleanup.run()
+
+
+def test_between_sets_window_bounds() -> None:
+    start_after = datetime.datetime(2024, 5, 1, 0, 0, 0)
+    end_before = datetime.datetime(2024, 6, 1, 0, 0, 0)
+    cleanup = WorkflowRunCleanup(days=30, batch_size=10, start_after=start_after, end_before=end_before)
+
+    assert cleanup.window_start == start_after
+    assert cleanup.window_end == end_before
+
+
+def test_between_requires_both_boundaries() -> None:
+    with pytest.raises(ValueError):
+        WorkflowRunCleanup(days=30, batch_size=10, start_after=datetime.datetime.now(), end_before=None)
+    with pytest.raises(ValueError):
+        WorkflowRunCleanup(days=30, batch_size=10, start_after=None, end_before=datetime.datetime.now())
+
+
+def test_between_requires_end_after_start() -> None:
+    start_after = datetime.datetime(2024, 6, 1, 0, 0, 0)
+    end_before = datetime.datetime(2024, 5, 1, 0, 0, 0)
+    with pytest.raises(ValueError):
+        WorkflowRunCleanup(days=30, batch_size=10, start_after=start_after, end_before=end_before)
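The new unit tests monkeypatch the Session factory, the db handle, and the billing lookups, so they are expected to run without a database or billing API, e.g. with pytest api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py (path taken from the diff; the exact pytest invocation depends on the local setup).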