From dd949a23e1721b4d057ad3e4c83e25c209c87608 Mon Sep 17 00:00:00 2001 From: hjlarry Date: Fri, 12 Dec 2025 10:19:16 +0800 Subject: [PATCH] add clean sandbox workflow runs --- api/.env.example | 1 + api/commands.py | 41 ++++ api/configs/feature/__init__.py | 4 + api/extensions/ext_celery.py | 7 + api/extensions/ext_commands.py | 2 + ...d7c23e_add_workflow_runs_created_at_idx.py | 29 +++ .../api_workflow_run_repository.py | 19 ++ .../sqlalchemy_api_workflow_run_repository.py | 118 +++++++++++- ...alchemy_workflow_trigger_log_repository.py | 20 +- api/schedule/clean_workflow_runs_task.py | 30 +++ api/services/billing_service.py | 31 ++++ ...ear_free_plan_expired_workflow_run_logs.py | 145 +++++++++++++++ ..._sqlalchemy_api_workflow_run_repository.py | 62 +++++++ ...alchemy_workflow_trigger_log_repository.py | 31 ++++ ...ear_free_plan_expired_workflow_run_logs.py | 175 ++++++++++++++++++ docker/.env.example | 1 + docker/docker-compose.yaml | 1 + 17 files changed, 714 insertions(+), 3 deletions(-) create mode 100644 api/migrations/versions/2025_12_10_1504-8a7f2ad7c23e_add_workflow_runs_created_at_idx.py create mode 100644 api/schedule/clean_workflow_runs_task.py create mode 100644 api/services/clear_free_plan_expired_workflow_run_logs.py create mode 100644 api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py create mode 100644 api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py diff --git a/api/.env.example b/api/.env.example index 516a119d98..9fb4211d18 100644 --- a/api/.env.example +++ b/api/.env.example @@ -552,6 +552,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false ENABLE_CREATE_TIDB_SERVERLESS_TASK=false ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false ENABLE_CLEAN_MESSAGES=false +ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true diff --git a/api/commands.py b/api/commands.py index a8d89ac200..9a990459c0 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1,4 +1,5 @@ import base64 +import datetime import json import logging import secrets @@ -41,6 +42,7 @@ from models.provider_ids import DatasourceProviderID, ToolProviderID from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding from models.tools import ToolOAuthSystemClient from services.account_service import AccountService, RegisterService, TenantService +from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration @@ -852,6 +854,45 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[ click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) +@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.") +@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.") +@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.") +@click.option( + "--start-after", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-after.", +) +def clean_workflow_runs( + days: int, + batch_size: int, + start_after: datetime.datetime | None, + end_before: datetime.datetime | None, +): + """ + Clean workflow runs and related workflow data for free tenants. + """ + if (start_after is None) ^ (end_before is None): + raise click.UsageError("--start-after and --end-before must be provided together.") + + click.echo(click.style("Starting workflow run cleanup.", fg="white")) + + WorkflowRunCleanup( + days=days, + batch_size=batch_size, + start_after=start_after, + end_before=end_before, + ).run() + + click.echo(click.style("Workflow run cleanup completed.", fg="green")) + + @click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") @click.command("clear-orphaned-file-records", help="Clear orphaned file records.") def clear_orphaned_file_records(force: bool): diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index e16ca52f46..0aaf5dcce8 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1096,6 +1096,10 @@ class CeleryScheduleTasksConfig(BaseSettings): description="Enable clean messages task", default=False, ) + ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field( + description="Enable scheduled workflow run cleanup task", + default=False, + ) ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field( description="Enable mail clean document notify task", default=False, diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 5cf4984709..c19763372d 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -160,6 +160,13 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise", "schedule": crontab(minute="0", hour="2"), } + if dify_config.ENABLE_WORKFLOW_RUN_CLEANUP_TASK: + # for saas only + imports.append("schedule.clean_workflow_runs_task") + beat_schedule["clean_workflow_runs_task"] = { + "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task", + "schedule": crontab(minute="0", hour="0"), + } if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: imports.append("schedule.workflow_schedule_task") beat_schedule["workflow_schedule_task"] = { diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 71a63168a5..6f6322827c 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + clean_workflow_runs, cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, @@ -54,6 +55,7 @@ def init_app(app: DifyApp): setup_datasource_oauth_client, transform_datasource_credentials, install_rag_pipeline_plugins, + clean_workflow_runs, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/migrations/versions/2025_12_10_1504-8a7f2ad7c23e_add_workflow_runs_created_at_idx.py b/api/migrations/versions/2025_12_10_1504-8a7f2ad7c23e_add_workflow_runs_created_at_idx.py new file mode 100644 index 0000000000..7968429ca8 --- /dev/null +++ b/api/migrations/versions/2025_12_10_1504-8a7f2ad7c23e_add_workflow_runs_created_at_idx.py @@ -0,0 +1,29 @@ +"""Add index on workflow_runs.created_at + +Revision ID: 8a7f2ad7c23e +Revises: d57accd375ae +Create Date: 2025-12-10 15:04:00.000000 +""" + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "8a7f2ad7c23e" +down_revision = "d57accd375ae" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("workflow_runs", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("workflow_runs_created_at_idx"), + ["created_at"], + unique=False, + ) + + +def downgrade(): + with op.batch_alter_table("workflow_runs", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("workflow_runs_created_at_idx")) diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index fd547c78ba..cb9c8921b6 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -253,6 +253,25 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... + def get_runs_batch_by_time_range( + self, + start_after: datetime | None, + end_before: datetime, + last_seen: tuple[datetime, str] | None, + batch_size: int, + ) -> Sequence[WorkflowRun]: + """ + Fetch a batch of workflow runs within a time window using keyset pagination. + """ + ... + + def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]: + """ + Delete workflow runs and their related records (node executions, offloads, app logs, + trigger logs, pauses, pause reasons). + """ + ... + def create_workflow_pause( self, workflow_run_id: str, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index b172c6a3ac..95de006d98 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -40,10 +40,19 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason, WorkflowRun +from models.workflow import ( + WorkflowAppLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPauseReason, + WorkflowRun, +) +from models.workflow import ( + WorkflowPause as WorkflowPauseModel, +) from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from repositories.types import ( AverageInteractionStats, DailyRunsStats, @@ -314,6 +323,111 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id) return total_deleted + def get_runs_batch_by_time_range( + self, + start_after: datetime | None, + end_before: datetime, + last_seen: tuple[datetime, str] | None, + batch_size: int, + ) -> Sequence[WorkflowRun]: + with self._session_maker() as session: + stmt = ( + select(WorkflowRun) + .where( + WorkflowRun.created_at < end_before, + WorkflowRun.status.in_( + [ + WorkflowExecutionStatus.SUCCEEDED.value, + WorkflowExecutionStatus.FAILED.value, + WorkflowExecutionStatus.STOPPED.value, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, + ] + ), + ) + .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc()) + .limit(batch_size) + ) + + if start_after: + stmt = stmt.where(WorkflowRun.created_at >= start_after) + + if last_seen: + stmt = stmt.where( + or_( + WorkflowRun.created_at > last_seen[0], + and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]), + ) + ) + + return session.scalars(stmt).all() + + def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]: + if not run_ids: + return { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + with self._session_maker() as session: + node_execution_ids = session.scalars( + select(WorkflowNodeExecutionModel.id).where(WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)) + ).all() + + offloads_deleted = 0 + if node_execution_ids: + offloads_result = session.execute( + delete(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + ) + offloads_deleted = cast(CursorResult, offloads_result).rowcount or 0 + + node_executions_deleted = 0 + if node_execution_ids: + node_executions_result = session.execute( + delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids)) + ) + node_executions_deleted = cast(CursorResult, node_executions_result).rowcount or 0 + + app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))) + app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0 + + pause_ids = session.scalars( + select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids)) + ).all() + pause_reasons_deleted = 0 + pauses_deleted = 0 + + if pause_ids: + pause_reasons_result = session.execute( + delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)) + ) + pause_reasons_deleted = cast(CursorResult, pause_reasons_result).rowcount or 0 + pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids))) + pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0 + + trigger_logs_deleted = SQLAlchemyWorkflowTriggerLogRepository(session).delete_by_run_ids(run_ids) + + runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))) + runs_deleted = cast(CursorResult, runs_result).rowcount or 0 + + session.commit() + + return { + "runs": runs_deleted, + "node_executions": node_executions_deleted, + "offloads": offloads_deleted, + "app_logs": app_logs_deleted, + "trigger_logs": trigger_logs_deleted, + "pauses": pauses_deleted, + "pause_reasons": pause_reasons_deleted, + } + def create_workflow_pause( self, workflow_run_id: str, diff --git a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py index 0d67e286b0..d01c35e5ab 100644 --- a/api/repositories/sqlalchemy_workflow_trigger_log_repository.py +++ b/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @@ -4,8 +4,10 @@ SQLAlchemy implementation of WorkflowTriggerLogRepository. from collections.abc import Sequence from datetime import UTC, datetime, timedelta +from typing import cast -from sqlalchemy import and_, select +from sqlalchemy import and_, delete, select +from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session from models.enums import WorkflowTriggerStatus @@ -84,3 +86,19 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository): ) return list(self.session.scalars(query).all()) + + def delete_by_run_ids(self, run_ids: Sequence[str]) -> int: + """ + Delete trigger logs associated with the given workflow run ids. + + Args: + run_ids: Collection of workflow run identifiers. + + Returns: + Number of rows deleted. + """ + if not run_ids: + return 0 + + result = self.session.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))) + return cast(CursorResult, result).rowcount or 0 diff --git a/api/schedule/clean_workflow_runs_task.py b/api/schedule/clean_workflow_runs_task.py new file mode 100644 index 0000000000..b59dc7f823 --- /dev/null +++ b/api/schedule/clean_workflow_runs_task.py @@ -0,0 +1,30 @@ +import click + +import app +from configs import dify_config +from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + +CLEANUP_QUEUE = "retention" + + +@app.celery.task(queue=CLEANUP_QUEUE) +def clean_workflow_runs_task() -> None: + """ + Scheduled cleanup for workflow runs and related records (sandbox tenants only). + """ + click.echo( + click.style( + f"Scheduled workflow run cleanup starting: cutoff={dify_config.WORKFLOW_LOG_RETENTION_DAYS} days, " + f"batch={dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE}", + fg="green", + ) + ) + + WorkflowRunCleanup( + days=dify_config.WORKFLOW_LOG_RETENTION_DAYS, + batch_size=dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE, + start_after=None, + end_before=None, + ).run() + + click.echo(click.style("Scheduled workflow run cleanup finished.", fg="green")) diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 54e1c9d285..cd7b5fc389 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -1,4 +1,6 @@ +import logging import os +from collections.abc import Sequence from typing import Literal import httpx @@ -11,6 +13,8 @@ from extensions.ext_redis import redis_client from libs.helper import RateLimiter from models import Account, TenantAccountJoin, TenantAccountRole +logger = logging.getLogger(__name__) + class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") @@ -25,6 +29,33 @@ class BillingService: billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info + @classmethod + def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, str]: + """ + Bulk billing info fetch via billing API. + + Payload: {"tenant_ids": ["t1", "t2", ...]} (max 200 per request) + + Returns: + Mapping of tenant_id -> plan + """ + results: dict[str, str] = {} + + chunk_size = 200 + for i in range(0, len(tenant_ids), chunk_size): + chunk = tenant_ids[i : i + chunk_size] + try: + resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk}) + data = resp.get("data", {}) + for tenant_id, plan in data.items(): + if isinstance(plan, str): + results[tenant_id] = plan + except Exception: + logger.exception("Failed to fetch billing info batch for tenants: %s", chunk) + continue + + return results + @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): params = {"tenant_id": tenant_id} diff --git a/api/services/clear_free_plan_expired_workflow_run_logs.py b/api/services/clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..6734c0b020 --- /dev/null +++ b/api/services/clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,145 @@ +import datetime +import logging +from collections.abc import Iterable + +import click +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from enums.cloud_plan import CloudPlan +from extensions.ext_database import db +from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from services.billing_service import BillingService + +logger = logging.getLogger(__name__) + + +class WorkflowRunCleanup: + def __init__( + self, + days: int, + batch_size: int, + start_after: datetime.datetime | None = None, + end_before: datetime.datetime | None = None, + repo: APIWorkflowRunRepository | None = None, + ): + if (start_after is None) ^ (end_before is None): + raise ValueError("start_after and end_before must be both set or both omitted.") + + computed_cutoff = datetime.datetime.now() - datetime.timedelta(days=days) + self.window_start = start_after + self.window_end = end_before or computed_cutoff + + if self.window_start and self.window_end <= self.window_start: + raise ValueError("end_before must be greater than start_after.") + + self.batch_size = batch_size + self.billing_cache: dict[str, CloudPlan | None] = {} + if repo: + self.repo = repo + else: + # Lazy import to avoid circular dependency during module import + from repositories.factory import DifyAPIRepositoryFactory + + self.repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) + + def run(self) -> None: + click.echo( + click.style( + f"Cleaning workflow runs " + f"{'between ' + self.window_start.isoformat() + ' and ' if self.window_start else 'before '}" + f"{self.window_end.isoformat()} (batch={self.batch_size})", + fg="white", + ) + ) + + total_runs_deleted = 0 + batch_index = 0 + last_seen: tuple[datetime.datetime, str] | None = None + + while True: + run_rows = self.repo.get_runs_batch_by_time_range( + start_after=self.window_start, + end_before=self.window_end, + last_seen=last_seen, + batch_size=self.batch_size, + ) + if not run_rows: + break + + batch_index += 1 + last_seen = (run_rows[-1].created_at, run_rows[-1].id) + tenant_ids = {row.tenant_id for row in run_rows} + free_tenants = self._filter_free_tenants(tenant_ids) + free_run_ids = [row.id for row in run_rows if row.tenant_id in free_tenants] + paid_or_skipped = len(run_rows) - len(free_run_ids) + + if not free_run_ids: + click.echo( + click.style( + f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)", + fg="yellow", + ) + ) + continue + + try: + counts = self.repo.delete_runs_with_related(free_run_ids) + except Exception: + logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) + raise + + total_runs_deleted += counts["runs"] + click.echo( + click.style( + f"[batch #{batch_index}] deleted runs: {counts['runs']} " + f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); " + f"skipped {paid_or_skipped} paid/unknown", + fg="green", + ) + ) + + if self.window_start: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}" + ) + + click.echo(click.style(summary_message, fg="white")) + + def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]: + if not dify_config.BILLING_ENABLED: + return set(tenant_ids) + + tenant_id_list = list(tenant_ids) + uncached_tenants = [tenant_id for tenant_id in tenant_id_list if tenant_id not in self.billing_cache] + + if uncached_tenants: + try: + bulk_info = BillingService.get_info_bulk(uncached_tenants) + except Exception: + bulk_info = {} + logger.exception("Failed to fetch billing plans in bulk for tenants: %s", uncached_tenants) + + for tenant_id in uncached_tenants: + plan: CloudPlan | None = None + info = bulk_info.get(tenant_id) + if info: + try: + plan = CloudPlan(info) + except Exception: + logger.exception("Failed to parse billing plan for tenant %s", tenant_id) + else: + logger.warning("Missing billing info for tenant %s in bulk resp; treating as non-free", tenant_id) + + self.billing_cache[tenant_id] = plan + + return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX} diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 0c34676252..90a3ef6985 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -4,6 +4,7 @@ from datetime import UTC, datetime from unittest.mock import Mock, patch import pytest +from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session, sessionmaker from core.workflow.enums import WorkflowExecutionStatus @@ -104,6 +105,42 @@ class TestDifyAPISQLAlchemyWorkflowRunRepository: return pause +class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_get_runs_batch_by_time_range_filters_terminal_statuses( + self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock + ): + scalar_result = Mock() + scalar_result.all.return_value = [] + mock_session.scalars.return_value = scalar_result + + repository.get_runs_batch_by_time_range( + start_after=None, + end_before=datetime(2024, 1, 1), + last_seen=None, + batch_size=50, + ) + + stmt = mock_session.scalars.call_args[0][0] + compiled_sql = str( + stmt.compile( + dialect=postgresql.dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + + assert "workflow_runs.status" in compiled_sql + for status in ( + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.STOPPED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + ): + assert f"'{status.value}'" in compiled_sql + + assert "'running'" not in compiled_sql + assert "'paused'" not in compiled_sql + + class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test create_workflow_pause method.""" @@ -181,6 +218,31 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): ) +class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): + def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): + node_ids_result = Mock() + node_ids_result.all.return_value = [] + pause_ids_result = Mock() + pause_ids_result.all.return_value = [] + mock_session.scalars.side_effect = [node_ids_result, pause_ids_result] + + # app_logs delete, runs delete + mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)] + + fake_trigger_repo = Mock() + fake_trigger_repo.delete_by_run_ids.return_value = 3 + + with patch( + "repositories.sqlalchemy_api_workflow_run_repository.SQLAlchemyWorkflowTriggerLogRepository", + return_value=fake_trigger_repo, + ): + counts = repository.delete_runs_with_related(["run-1"]) + + fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"]) + assert counts["trigger_logs"] == 3 + assert counts["runs"] == 1 + + class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): """Test resume_workflow_pause method.""" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py new file mode 100644 index 0000000000..d409618211 --- /dev/null +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_workflow_trigger_log_repository.py @@ -0,0 +1,31 @@ +from unittest.mock import Mock + +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import Session + +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository + + +def test_delete_by_run_ids_executes_delete(): + session = Mock(spec=Session) + session.execute.return_value = Mock(rowcount=2) + repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + deleted = repo.delete_by_run_ids(["run-1", "run-2"]) + + stmt = session.execute.call_args[0][0] + compiled_sql = str(stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})) + assert "workflow_trigger_logs" in compiled_sql + assert "'run-1'" in compiled_sql + assert "'run-2'" in compiled_sql + assert deleted == 2 + + +def test_delete_by_run_ids_empty_short_circuits(): + session = Mock(spec=Session) + repo = SQLAlchemyWorkflowTriggerLogRepository(session) + + deleted = repo.delete_by_run_ids([]) + + session.execute.assert_not_called() + assert deleted == 0 diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..415bb9b67d --- /dev/null +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,175 @@ +import datetime +from typing import Any + +import pytest + +from services import clear_free_plan_expired_workflow_run_logs as cleanup_module +from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +class FakeRun: + def __init__(self, run_id: str, tenant_id: str, created_at: datetime.datetime) -> None: + self.id = run_id + self.tenant_id = tenant_id + self.created_at = created_at + + +class FakeRepo: + def __init__(self, batches: list[list[FakeRun]], delete_result: dict[str, int] | None = None) -> None: + self.batches = batches + self.call_idx = 0 + self.deleted: list[list[str]] = [] + self.delete_result = delete_result or { + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + def get_runs_batch_by_time_range( + self, + start_after: datetime.datetime | None, + end_before: datetime.datetime, + last_seen: tuple[datetime.datetime, str] | None, + batch_size: int, + ) -> list[FakeRun]: + if self.call_idx >= len(self.batches): + return [] + batch = self.batches[self.call_idx] + self.call_idx += 1 + return batch + + def delete_runs_with_related(self, run_ids: list[str]) -> dict[str, int]: + self.deleted.append(list(run_ids)) + result = self.delete_result.copy() + result["runs"] = len(run_ids) + return result + + +def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([])) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + def fail_bulk(_: list[str]) -> dict[str, dict[str, Any]]: + raise RuntimeError("should not call") + + monkeypatch.setattr(cleanup_module.BillingService, "get_info_bulk", staticmethod(fail_bulk)) + + tenants = {"t1", "t2"} + free = cleanup._filter_free_tenants(tenants) + + assert free == tenants + + +def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([])) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX + cleanup.billing_cache["t_paid"] = cleanup_module.CloudPlan.TEAM + monkeypatch.setattr( + cleanup_module.BillingService, + "get_info_bulk", + staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "sandbox")), + ) + + free = cleanup._filter_free_tenants({"t_free", "t_paid", "t_missing"}) + + assert free == {"t_free", "t_missing"} + + +def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([])) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_info_bulk", + staticmethod(lambda tenant_ids: (_ for _ in ()).throw(RuntimeError("boom"))), + ) + + free = cleanup._filter_free_tenants({"t1", "t2"}) + + assert free == set() + + +def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[ + [ + FakeRun("run-free", "t_free", cutoff), + FakeRun("run-paid", "t_paid", cutoff), + ] + ] + ) + cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=repo) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX + cleanup.billing_cache["t_paid"] = cleanup_module.CloudPlan.TEAM + monkeypatch.setattr( + cleanup_module.BillingService, + "get_info_bulk", + staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "sandbox")), + ) + + cleanup.run() + + assert repo.deleted == [["run-free"]] + + +def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo(batches=[[FakeRun("run-paid", "t_paid", cutoff)]]) + cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=repo) + + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + cleanup_module.BillingService, + "get_info_bulk", + staticmethod(lambda tenant_ids: dict.fromkeys(tenant_ids, "team")), + ) + + cleanup.run() + + assert repo.deleted == [] + + +def test_run_exits_on_empty_batch() -> None: + cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([])) + + cleanup.run() + + +def test_between_sets_window_bounds() -> None: + start_after = datetime.datetime(2024, 5, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 6, 1, 0, 0, 0) + cleanup = WorkflowRunCleanup( + days=30, batch_size=10, start_after=start_after, end_before=end_before, repo=FakeRepo([]) + ) + + assert cleanup.window_start == start_after + assert cleanup.window_end == end_before + + +def test_between_requires_both_boundaries() -> None: + with pytest.raises(ValueError): + WorkflowRunCleanup( + days=30, batch_size=10, start_after=datetime.datetime.now(), end_before=None, repo=FakeRepo([]) + ) + with pytest.raises(ValueError): + WorkflowRunCleanup( + days=30, batch_size=10, start_after=None, end_before=datetime.datetime.now(), repo=FakeRepo([]) + ) + + +def test_between_requires_end_after_start() -> None: + start_after = datetime.datetime(2024, 6, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 5, 1, 0, 0, 0) + with pytest.raises(ValueError): + WorkflowRunCleanup(days=30, batch_size=10, start_after=start_after, end_before=end_before, repo=FakeRepo([])) diff --git a/docker/.env.example b/docker/.env.example index 04088b72a8..40f9323c43 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1422,6 +1422,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false ENABLE_CREATE_TIDB_SERVERLESS_TASK=false ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false ENABLE_CLEAN_MESSAGES=false +ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false ENABLE_DATASETS_QUEUE_MONITOR=false ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 68f5726797..257a383990 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -627,6 +627,7 @@ x-shared-env: &shared-api-worker-env ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false} ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: ${ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:-false} ENABLE_CLEAN_MESSAGES: ${ENABLE_CLEAN_MESSAGES:-false} + ENABLE_WORKFLOW_RUN_CLEANUP_TASK: ${ENABLE_WORKFLOW_RUN_CLEANUP_TASK:-false} ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false} ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false} ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true}