diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
index fd547c78ba..dea5c781d0 100644
--- a/api/repositories/api_workflow_run_repository.py
+++ b/api/repositories/api_workflow_run_repository.py
@@ -253,6 +253,25 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
         """
         ...
 
+    def get_runs_batch_for_cleanup(
+        self,
+        start_after: datetime | None,
+        end_before: datetime,
+        last_seen: tuple[datetime, str] | None,
+        batch_size: int,
+    ) -> Sequence[WorkflowRun]:
+        """
+        Fetch a batch of workflow runs within a time window using keyset pagination for cleanup.
+        """
+        ...
+
+    def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
+        """
+        Delete workflow runs and their related records (node executions, offloads, app logs,
+        trigger logs, pauses, pause reasons).
+        """
+        ...
+
     def create_workflow_pause(
         self,
         workflow_run_id: str,
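A minimal sketch of how a caller is expected to drive these two protocol methods together, assuming some `repo` implementing `APIWorkflowRunRepository` and a precomputed `cutoff`; the driver loop and variable names are illustrative, not part of this PR:

    from datetime import datetime

    # Hypothetical pagination driver: pages through all runs created before
    # `cutoff` and deletes each page together with its related records.
    last_seen: tuple[datetime, str] | None = None
    while True:
        batch = repo.get_runs_batch_for_cleanup(
            start_after=None,      # no lower bound: start from the oldest run
            end_before=cutoff,     # only runs created before the cutoff
            last_seen=last_seen,   # keyset cursor from the previous page
            batch_size=1000,
        )
        if not batch:
            break
        last_seen = (batch[-1].created_at, batch[-1].id)  # advance the cursor
        counts = repo.delete_runs_with_related([run.id for run in batch])
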
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
index b172c6a3ac..84ba62076d 100644
--- a/api/repositories/sqlalchemy_api_workflow_run_repository.py
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -40,8 +40,17 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.time_parser import get_time_threshold
 from libs.uuid_utils import uuidv7
 from models.enums import WorkflowRunTriggeredFrom
-from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowPauseReason, WorkflowRun
+from models.trigger import WorkflowTriggerLog
+from models.workflow import (
+    WorkflowAppLog,
+    WorkflowNodeExecutionModel,
+    WorkflowNodeExecutionOffload,
+    WorkflowPauseReason,
+    WorkflowRun,
+)
+from models.workflow import (
+    WorkflowPause as WorkflowPauseModel,
+)
 from repositories.api_workflow_run_repository import APIWorkflowRunRepository
 from repositories.entities.workflow_pause import WorkflowPauseEntity
 from repositories.types import (
@@ -314,6 +323,115 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
         logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
         return total_deleted
 
+    def get_runs_batch_for_cleanup(
+        self,
+        start_after: datetime | None,
+        end_before: datetime,
+        last_seen: tuple[datetime, str] | None,
+        batch_size: int,
+    ) -> Sequence[WorkflowRun]:
+        with self._session_maker() as session:
+            stmt = (
+                select(WorkflowRun)
+                .where(WorkflowRun.created_at < end_before)
+                .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
+                .limit(batch_size)
+            )
+
+            if start_after:
+                stmt = stmt.where(WorkflowRun.created_at >= start_after)
+
+            if last_seen:
+                stmt = stmt.where(
+                    or_(
+                        WorkflowRun.created_at > last_seen[0],
+                        and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
+                    )
+                )
+
+            return session.scalars(stmt).all()
+
+    def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
+        if not run_ids:
+            return {
+                "runs": 0,
+                "node_executions": 0,
+                "offloads": 0,
+                "app_logs": 0,
+                "trigger_logs": 0,
+                "pauses": 0,
+                "pause_reasons": 0,
+            }
+
+        with self._session_maker() as session:
+            node_execution_ids = session.scalars(
+                select(WorkflowNodeExecutionModel.id).where(
+                    WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+                )
+            ).all()
+
+            offloads_deleted = 0
+            if node_execution_ids:
+                offloads_deleted = (
+                    session.query(WorkflowNodeExecutionOffload)
+                    .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
+                    .delete(synchronize_session=False)
+                )
+
+            node_executions_deleted = 0
+            if node_execution_ids:
+                node_executions_deleted = (
+                    session.query(WorkflowNodeExecutionModel)
+                    .where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
+                    .delete(synchronize_session=False)
+                )
+
+            app_logs_deleted = (
+                session.query(WorkflowAppLog)
+                .where(WorkflowAppLog.workflow_run_id.in_(run_ids))
+                .delete(synchronize_session=False)
+            )
+
+            pause_ids = session.scalars(
+                select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
+            ).all()
+            pause_reasons_deleted = 0
+            pauses_deleted = 0
+
+            if pause_ids:
+                pause_reasons_deleted = (
+                    session.query(WorkflowPauseReason)
+                    .where(WorkflowPauseReason.pause_id.in_(pause_ids))
+                    .delete(synchronize_session=False)
+                )
+                pauses_deleted = (
+                    session.query(WorkflowPauseModel)
+                    .where(WorkflowPauseModel.id.in_(pause_ids))
+                    .delete(synchronize_session=False)
+                )
+
+            trigger_logs_deleted = (
+                session.query(WorkflowTriggerLog)
+                .where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
+                .delete(synchronize_session=False)
+            )
+
+            runs_deleted = (
+                session.query(WorkflowRun).where(WorkflowRun.id.in_(run_ids)).delete(synchronize_session=False)
+            )
+
+            session.commit()
+
+            return {
+                "runs": runs_deleted,
+                "node_executions": node_executions_deleted,
+                "offloads": offloads_deleted,
+                "app_logs": app_logs_deleted,
+                "trigger_logs": trigger_logs_deleted,
+                "pauses": pauses_deleted,
+                "pause_reasons": pause_reasons_deleted,
+            }
+
     def create_workflow_pause(
         self,
         workflow_run_id: str,
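Two implementation details worth noting. The `or_`/`and_` pair is a standard keyset (seek) predicate; on backends that support row-value comparisons, such as PostgreSQL, the same condition can be written more compactly, as in this sketch (not part of the PR):

    from sqlalchemy import tuple_

    # Equivalent seek predicate: (created_at, id) > (last_created_at, last_id)
    stmt = stmt.where(
        tuple_(WorkflowRun.created_at, WorkflowRun.id) > tuple_(last_seen[0], last_seen[1])
    )

Also, `delete_runs_with_related` deletes child rows before their parents (offloads before node executions, pause reasons before pauses), which suggests the schema does not rely on ON DELETE CASCADE. All deletions share one session and one final commit, so a failure before the commit discards the whole batch when the session closes.
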
diff --git a/api/services/billing_service.py b/api/services/billing_service.py
index be7bc55f13..cd7b5fc389 100644
--- a/api/services/billing_service.py
+++ b/api/services/billing_service.py
@@ -30,7 +30,7 @@ class BillingService:
         return billing_info
 
     @classmethod
-    def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, dict]:
+    def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, str]:
         """
         Bulk billing info fetch via billing API.
 
@@ -39,17 +39,17 @@ class BillingService:
         Returns:
             Mapping of tenant_id -> plan
         """
-
-        results: dict[str, dict] = {}
+        results: dict[str, str] = {}
         chunk_size = 200
 
         for i in range(0, len(tenant_ids), chunk_size):
             chunk = tenant_ids[i : i + chunk_size]
             try:
                 resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
-                data = resp.get("data", {}) if isinstance(resp, dict) else {}
-                if data:
-                    results.update(data)
+                data = resp.get("data", {})
+                for tenant_id, plan in data.items():
+                    if isinstance(plan, str):
+                        results[tenant_id] = plan
             except Exception:
                 logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
                 continue
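The narrowed return type makes the contract explicit: callers receive only tenant-to-plan-name strings, and malformed entries are dropped instead of leaking dicts into the result. Note that removing the `isinstance(resp, dict)` guard means a non-dict response now raises inside the `try` and is handled by the existing `except` block (logged, chunk skipped). A usage sketch; the plan names in the comment are assumptions, not values confirmed by this diff:

    # Assumed response shape from POST /subscription/plan/batch:
    #   {"data": {"tenant-a": "sandbox", "tenant-b": "team", "tenant-c": {...}}}
    plans = BillingService.get_info_bulk(["tenant-a", "tenant-b", "tenant-c"])
    # -> {"tenant-a": "sandbox", "tenant-b": "team"}   # non-string plans are skipped
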
diff --git a/api/services/clear_free_plan_expired_workflow_run_logs.py b/api/services/clear_free_plan_expired_workflow_run_logs.py
index cca5ab15d2..68ba7787d3 100644
--- a/api/services/clear_free_plan_expired_workflow_run_logs.py
+++ b/api/services/clear_free_plan_expired_workflow_run_logs.py
@@ -1,30 +1,19 @@
 import datetime
 import logging
-from collections.abc import Iterable, Sequence
-from dataclasses import dataclass
+from collections.abc import Iterable
 
 import click
-import sqlalchemy as sa
-from sqlalchemy import select
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
 
 from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
-from models import WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
-from models.trigger import WorkflowTriggerLog
-from models.workflow import WorkflowNodeExecutionOffload, WorkflowPause, WorkflowPauseReason
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
 from services.billing_service import BillingService
 
 logger = logging.getLogger(__name__)
 
 
-@dataclass(frozen=True)
-class WorkflowRunRow:
-    id: str
-    tenant_id: str
-    created_at: datetime.datetime
-
-
 class WorkflowRunCleanup:
     def __init__(
         self,
@@ -32,6 +32,7 @@ class WorkflowRunCleanup:
         batch_size: int,
         start_after: datetime.datetime | None = None,
         end_before: datetime.datetime | None = None,
+        repo: APIWorkflowRunRepository | None = None,
     ):
         if (start_after is None) ^ (end_before is None):
             raise ValueError("start_after and end_before must be both set or both omitted.")
@@ -45,6 +35,9 @@ class WorkflowRunCleanup:
         self.batch_size = batch_size
 
         self.billing_cache: dict[str, CloudPlan | None] = {}
+        self.repo = repo or DifyAPIRepositoryFactory.create_api_workflow_run_repository(
+            sessionmaker(bind=db.engine, expire_on_commit=False)
+        )
 
     def run(self) -> None:
         click.echo(
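Making the repository injectable (falling back to a factory-built instance bound to `db.engine`) also makes the cleanup loop unit-testable without a database. A minimal sketch with a hypothetical stub, assuming `batch_size` is the only other required constructor argument; `_StubRepo` and its canned returns are illustrative only:

    class _StubRepo:
        # Hypothetical in-memory stand-in for APIWorkflowRunRepository.
        def get_runs_batch_for_cleanup(self, start_after, end_before, last_seen, batch_size):
            return []  # pretend there is nothing left to clean up

        def delete_runs_with_related(self, run_ids):
            keys = ("runs", "node_executions", "offloads", "app_logs",
                    "trigger_logs", "pauses", "pause_reasons")
            return {key: 0 for key in keys}

    cleanup = WorkflowRunCleanup(batch_size=500, repo=_StubRepo())
    cleanup.run()  # `repo or ...` short-circuits, so db.engine is never touched
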
@@ -61,46 +54,48 @@ class WorkflowRunCleanup:
         last_seen: tuple[datetime.datetime, str] | None = None
 
         while True:
-            with Session(db.engine) as session:
-                run_rows = self._load_batch(session, last_seen)
-                if not run_rows:
-                    break
+            run_rows = self.repo.get_runs_batch_for_cleanup(
+                start_after=self.window_start,
+                end_before=self.window_end,
+                last_seen=last_seen,
+                batch_size=self.batch_size,
+            )
+            if not run_rows:
+                break
 
-                batch_index += 1
-                last_seen = (run_rows[-1].created_at, run_rows[-1].id)
-                tenant_ids = {row.tenant_id for row in run_rows}
-                free_tenants = self._filter_free_tenants(tenant_ids)
-                free_run_ids = [row.id for row in run_rows if row.tenant_id in free_tenants]
-                paid_or_skipped = len(run_rows) - len(free_run_ids)
+            batch_index += 1
+            last_seen = (run_rows[-1].created_at, run_rows[-1].id)
+            tenant_ids = {row.tenant_id for row in run_rows}
+            free_tenants = self._filter_free_tenants(tenant_ids)
+            free_run_ids = [row.id for row in run_rows if row.tenant_id in free_tenants]
+            paid_or_skipped = len(run_rows) - len(free_run_ids)
 
-                if not free_run_ids:
-                    click.echo(
-                        click.style(
-                            f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid)",
-                            fg="yellow",
-                        )
-                    )
-                    continue
-
-                try:
-                    counts = self._delete_runs(session, free_run_ids)
-                    session.commit()
-                except Exception:
-                    session.rollback()
-                    logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
-                    raise
-
-                total_runs_deleted += counts["runs"]
+            if not free_run_ids:
                 click.echo(
                     click.style(
-                        f"[batch #{batch_index}] deleted runs: {counts['runs']} "
-                        f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
-                        f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
-                        f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
-                        f"skipped {paid_or_skipped} paid/unknown",
-                        fg="green",
+                        f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)",
+                        fg="yellow",
                     )
                 )
+                continue
+
+            try:
+                counts = self.repo.delete_runs_with_related(free_run_ids)
+            except Exception:
+                logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
+                raise
+
+            total_runs_deleted += counts["runs"]
+            click.echo(
+                click.style(
+                    f"[batch #{batch_index}] deleted runs: {counts['runs']} "
+                    f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
+                    f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
+                    f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
+                    f"skipped {paid_or_skipped} paid/unknown",
+                    fg="green",
+                )
+            )
 
         if self.window_start:
             summary_message = (
@@ -114,28 +109,6 @@ class WorkflowRunCleanup:
 
         click.echo(click.style(summary_message, fg="white"))
 
-    def _load_batch(self, session: Session, last_seen: tuple[datetime.datetime, str] | None) -> list[WorkflowRunRow]:
-        stmt = (
-            select(WorkflowRun.id, WorkflowRun.tenant_id, WorkflowRun.created_at)
-            .where(WorkflowRun.created_at < self.window_end)
-            .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
-            .limit(self.batch_size)
-        )
-
-        if self.window_start:
-            stmt = stmt.where(WorkflowRun.created_at >= self.window_start)
-
-        if last_seen:
-            stmt = stmt.where(
-                sa.or_(
-                    WorkflowRun.created_at > last_seen[0],
-                    sa.and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
-                )
-            )
-
-        rows = session.execute(stmt).all()
-        return [WorkflowRunRow(id=row.id, tenant_id=row.tenant_id, created_at=row.created_at) for row in rows]
-
     def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
         tenant_id_list = list(tenant_ids)
         uncached_tenants = [tenant_id for tenant_id in tenant_id_list if tenant_id not in self.billing_cache]
@@ -161,68 +134,3 @@ class WorkflowRunCleanup:
             self.billing_cache[tenant_id] = plan
 
         return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX}
-
-    def _delete_runs(self, session: Session, workflow_run_ids: Sequence[str]) -> dict[str, int]:
-        node_execution_ids = session.scalars(
-            select(WorkflowNodeExecutionModel.id).where(
-                WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
-            )
-        ).all()
-
-        offloads_deleted = 0
-        if node_execution_ids:
-            offloads_deleted = (
-                session.query(WorkflowNodeExecutionOffload)
-                .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
-                .delete(synchronize_session=False)
-            )
-
-        node_executions_deleted = 0
-        if node_execution_ids:
-            node_executions_deleted = (
-                session.query(WorkflowNodeExecutionModel)
-                .where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
-                .delete(synchronize_session=False)
-            )
-
-        app_logs_deleted = (
-            session.query(WorkflowAppLog)
-            .where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids))
-            .delete(synchronize_session=False)
-        )
-
-        pause_ids = session.scalars(
-            select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(workflow_run_ids))
-        ).all()
-        pause_reasons_deleted = 0
-        pauses_deleted = 0
-
-        if pause_ids:
-            pause_reasons_deleted = (
-                session.query(WorkflowPauseReason)
-                .where(WorkflowPauseReason.pause_id.in_(pause_ids))
-                .delete(synchronize_session=False)
-            )
-            pauses_deleted = (
-                session.query(WorkflowPause).where(WorkflowPause.id.in_(pause_ids)).delete(synchronize_session=False)
-            )
-
-        trigger_logs_deleted = (
-            session.query(WorkflowTriggerLog)
-            .where(WorkflowTriggerLog.workflow_run_id.in_(workflow_run_ids))
-            .delete(synchronize_session=False)
-        )
-
-        runs_deleted = (
-            session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
-        )
-
-        return {
-            "runs": runs_deleted,
-            "node_executions": node_executions_deleted,
-            "offloads": offloads_deleted,
-            "app_logs": app_logs_deleted,
-            "trigger_logs": trigger_logs_deleted,
-            "pauses": pauses_deleted,
-            "pause_reasons": pause_reasons_deleted,
-        }
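With `_load_batch` and `_delete_runs` removed, the service keeps only plan filtering and progress reporting; all persistence, including commit/rollback responsibility, moves behind the repository, which is why the service-level `session.commit()`/`session.rollback()` calls disappear. For reference, a usage sketch of the retained constructor contract (the default cutoff applied when no window is given is computed elsewhere and not shown in this diff):

    import datetime

    # Full scan up to the default cutoff:
    WorkflowRunCleanup(batch_size=1000).run()

    # Bounded window: start_after and end_before must be supplied together,
    # otherwise __init__ raises ValueError.
    WorkflowRunCleanup(
        batch_size=1000,
        start_after=datetime.datetime(2024, 1, 1),
        end_before=datetime.datetime(2024, 2, 1),
    ).run()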