refactor to repo layer

hjlarry 2025-12-11 16:56:54 +08:00
parent 22443df772
commit 46d824b17f
4 changed files with 189 additions and 144 deletions

View File

@@ -253,6 +253,25 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
        """
        ...

    def get_runs_batch_for_cleanup(
        self,
        start_after: datetime | None,
        end_before: datetime,
        last_seen: tuple[datetime, str] | None,
        batch_size: int,
    ) -> Sequence[WorkflowRun]:
        """
        Fetch a batch of workflow runs within a time window using keyset pagination for cleanup.
        """
        ...

    def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
        """
        Delete workflow runs and their related records (node executions, offloads, app logs,
        trigger logs, pauses, pause reasons).
        """
        ...

    def create_workflow_pause(
        self,
        workflow_run_id: str,

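Note: a hedged sketch of how a caller might drive these two new protocol methods together; the helper name and batch size below are hypothetical, and the loop mirrors the cleanup command later in this commit.

from datetime import datetime

def cleanup_window(repo: APIWorkflowRunRepository, end_before: datetime) -> int:
    # Keyset pagination: carry (created_at, id) of the last row forward so each
    # batch resumes exactly where the previous one stopped, without OFFSET scans.
    last_seen: tuple[datetime, str] | None = None
    total = 0
    while True:
        runs = repo.get_runs_batch_for_cleanup(
            start_after=None,  # unbounded window start
            end_before=end_before,
            last_seen=last_seen,
            batch_size=1000,  # hypothetical batch size
        )
        if not runs:
            return total
        last_seen = (runs[-1].created_at, runs[-1].id)
        total += repo.delete_runs_with_related([run.id for run in runs])["runs"]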
View File

@@ -40,8 +40,17 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.time_parser import get_time_threshold
from libs.uuid_utils import uuidv7
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowPauseReason, WorkflowRun
from models.trigger import WorkflowTriggerLog
from models.workflow import (
    WorkflowAppLog,
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionOffload,
    WorkflowPauseReason,
    WorkflowRun,
)
from models.workflow import (
    WorkflowPause as WorkflowPauseModel,
)
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.types import (
@@ -314,6 +323,115 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
            logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
        return total_deleted

    def get_runs_batch_for_cleanup(
        self,
        start_after: datetime | None,
        end_before: datetime,
        last_seen: tuple[datetime, str] | None,
        batch_size: int,
    ) -> Sequence[WorkflowRun]:
        with self._session_maker() as session:
            stmt = (
                select(WorkflowRun)
                .where(WorkflowRun.created_at < end_before)
                .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
                .limit(batch_size)
            )
            if start_after:
                stmt = stmt.where(WorkflowRun.created_at >= start_after)
            if last_seen:
                stmt = stmt.where(
                    or_(
                        WorkflowRun.created_at > last_seen[0],
                        and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
                    )
                )
            return session.scalars(stmt).all()

    def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
        if not run_ids:
            return {
                "runs": 0,
                "node_executions": 0,
                "offloads": 0,
                "app_logs": 0,
                "trigger_logs": 0,
                "pauses": 0,
                "pause_reasons": 0,
            }
        with self._session_maker() as session:
            node_execution_ids = session.scalars(
                select(WorkflowNodeExecutionModel.id).where(
                    WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
                )
            ).all()
            offloads_deleted = 0
            if node_execution_ids:
                offloads_deleted = (
                    session.query(WorkflowNodeExecutionOffload)
                    .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
                    .delete(synchronize_session=False)
                )
            node_executions_deleted = 0
            if node_execution_ids:
                node_executions_deleted = (
                    session.query(WorkflowNodeExecutionModel)
                    .where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
                    .delete(synchronize_session=False)
                )
            app_logs_deleted = (
                session.query(WorkflowAppLog)
                .where(WorkflowAppLog.workflow_run_id.in_(run_ids))
                .delete(synchronize_session=False)
            )
            pause_ids = session.scalars(
                select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
            ).all()
            pause_reasons_deleted = 0
            pauses_deleted = 0
            if pause_ids:
                pause_reasons_deleted = (
                    session.query(WorkflowPauseReason)
                    .where(WorkflowPauseReason.pause_id.in_(pause_ids))
                    .delete(synchronize_session=False)
                )
                pauses_deleted = (
                    session.query(WorkflowPauseModel)
                    .where(WorkflowPauseModel.id.in_(pause_ids))
                    .delete(synchronize_session=False)
                )
            trigger_logs_deleted = (
                session.query(WorkflowTriggerLog)
                .where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
                .delete(synchronize_session=False)
            )
            runs_deleted = (
                session.query(WorkflowRun).where(WorkflowRun.id.in_(run_ids)).delete(synchronize_session=False)
            )
            session.commit()
            return {
                "runs": runs_deleted,
                "node_executions": node_executions_deleted,
                "offloads": offloads_deleted,
                "app_logs": app_logs_deleted,
                "trigger_logs": trigger_logs_deleted,
                "pauses": pauses_deleted,
                "pause_reasons": pause_reasons_deleted,
            }

    def create_workflow_pause(
        self,
        workflow_run_id: str,

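The or_/and_ pair above encodes the lexicographic row-value comparison (created_at, id) > (last_seen_at, last_seen_id). A plain-Python restatement of the predicate, purely illustrative (the function name is hypothetical):

from datetime import datetime

def is_after_cursor(created_at: datetime, run_id: str, last_seen: tuple[datetime, str]) -> bool:
    last_at, last_id = last_seen
    # Matches the SQL: created_at > :last_at OR (created_at = :last_at AND id > :last_id)
    return created_at > last_at or (created_at == last_at and run_id > last_id)
    # Equivalently: (created_at, run_id) > (last_at, last_id) as a tuple comparison.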
View File

@@ -30,7 +30,7 @@ class BillingService:
        return billing_info

    @classmethod
    def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, dict]:
    def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, str]:
        """
        Bulk billing info fetch via billing API.
@@ -39,17 +39,17 @@
        Returns:
            Mapping of tenant_id -> plan
        """
        results: dict[str, dict] = {}
        results: dict[str, str] = {}
        chunk_size = 200

        for i in range(0, len(tenant_ids), chunk_size):
            chunk = tenant_ids[i : i + chunk_size]
            try:
                resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
                data = resp.get("data", {}) if isinstance(resp, dict) else {}
                if data:
                    results.update(data)
                data = resp.get("data", {})
                for tenant_id, plan in data.items():
                    if isinstance(plan, str):
                        results[tenant_id] = plan
            except Exception:
                logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
                continue

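A hedged usage sketch of the bulk fetch (tenant IDs and plan strings below are invented). Because a failed chunk is only logged and skipped, callers should treat a missing tenant as "plan unknown" rather than as free:

plans = BillingService.get_info_bulk(["tenant-a", "tenant-b", "tenant-c"])
# e.g. {"tenant-a": "sandbox", "tenant-b": "professional"}
# "tenant-c" is absent if its chunk failed or the API returned no string plan for it.
sandbox_tenants = {tid for tid, plan in plans.items() if plan == "sandbox"}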
View File

@@ -1,30 +1,19 @@
import datetime
import logging
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from collections.abc import Iterable

import click
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
from models.trigger import WorkflowTriggerLog
from models.workflow import WorkflowNodeExecutionOffload, WorkflowPause, WorkflowPauseReason
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class WorkflowRunRow:
    id: str
    tenant_id: str
    created_at: datetime.datetime


class WorkflowRunCleanup:
    def __init__(
        self,
@@ -32,6 +21,7 @@ class WorkflowRunCleanup:
        batch_size: int,
        start_after: datetime.datetime | None = None,
        end_before: datetime.datetime | None = None,
        repo: APIWorkflowRunRepository | None = None,
    ):
        if (start_after is None) ^ (end_before is None):
            raise ValueError("start_after and end_before must be both set or both omitted.")
@@ -45,6 +35,9 @@
        self.batch_size = batch_size
        self.billing_cache: dict[str, CloudPlan | None] = {}
        self.repo = repo or DifyAPIRepositoryFactory.create_api_workflow_run_repository(
            sessionmaker(bind=db.engine, expire_on_commit=False)
        )

    def run(self) -> None:
        click.echo(
@@ -61,46 +54,48 @@
        last_seen: tuple[datetime.datetime, str] | None = None

        while True:
            with Session(db.engine) as session:
                run_rows = self._load_batch(session, last_seen)
                if not run_rows:
                    break
            run_rows = self.repo.get_runs_batch_for_cleanup(
                start_after=self.window_start,
                end_before=self.window_end,
                last_seen=last_seen,
                batch_size=self.batch_size,
            )
            if not run_rows:
                break
                batch_index += 1
                last_seen = (run_rows[-1].created_at, run_rows[-1].id)
                tenant_ids = {row.tenant_id for row in run_rows}
                free_tenants = self._filter_free_tenants(tenant_ids)
                free_run_ids = [row.id for row in run_rows if row.tenant_id in free_tenants]
                paid_or_skipped = len(run_rows) - len(free_run_ids)
            batch_index += 1
            last_seen = (run_rows[-1].created_at, run_rows[-1].id)
            tenant_ids = {row.tenant_id for row in run_rows}
            free_tenants = self._filter_free_tenants(tenant_ids)
            free_run_ids = [row.id for row in run_rows if row.tenant_id in free_tenants]
            paid_or_skipped = len(run_rows) - len(free_run_ids)
                if not free_run_ids:
                    click.echo(
                        click.style(
                            f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid)",
                            fg="yellow",
                        )
                    )
                    continue

                try:
                    counts = self._delete_runs(session, free_run_ids)
                    session.commit()
                except Exception:
                    session.rollback()
                    logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
                    raise

                total_runs_deleted += counts["runs"]
                click.echo(
                    click.style(
                        f"[batch #{batch_index}] deleted runs: {counts['runs']} "
                        f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
                        f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
                        f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
                        f"skipped {paid_or_skipped} paid/unknown",
                        fg="green",
                    )
                )
            if not free_run_ids:
                click.echo(
                    click.style(
                        f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)",
                        fg="yellow",
                    )
                )
                continue

            try:
                counts = self.repo.delete_runs_with_related(free_run_ids)
            except Exception:
                logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
                raise

            total_runs_deleted += counts["runs"]
            click.echo(
                click.style(
                    f"[batch #{batch_index}] deleted runs: {counts['runs']} "
                    f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
                    f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
                    f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
                    f"skipped {paid_or_skipped} paid/unknown",
                    fg="green",
                )
            )

        if self.window_start:
            summary_message = (
@@ -114,28 +109,6 @@
        click.echo(click.style(summary_message, fg="white"))

    def _load_batch(self, session: Session, last_seen: tuple[datetime.datetime, str] | None) -> list[WorkflowRunRow]:
        stmt = (
            select(WorkflowRun.id, WorkflowRun.tenant_id, WorkflowRun.created_at)
            .where(WorkflowRun.created_at < self.window_end)
            .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
            .limit(self.batch_size)
        )
        if self.window_start:
            stmt = stmt.where(WorkflowRun.created_at >= self.window_start)
        if last_seen:
            stmt = stmt.where(
                sa.or_(
                    WorkflowRun.created_at > last_seen[0],
                    sa.and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
                )
            )
        rows = session.execute(stmt).all()
        return [WorkflowRunRow(id=row.id, tenant_id=row.tenant_id, created_at=row.created_at) for row in rows]

    def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
        tenant_id_list = list(tenant_ids)
        uncached_tenants = [tenant_id for tenant_id in tenant_id_list if tenant_id not in self.billing_cache]
@@ -161,68 +134,3 @@
            self.billing_cache[tenant_id] = plan
        return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX}

    def _delete_runs(self, session: Session, workflow_run_ids: Sequence[str]) -> dict[str, int]:
        node_execution_ids = session.scalars(
            select(WorkflowNodeExecutionModel.id).where(
                WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
            )
        ).all()

        offloads_deleted = 0
        if node_execution_ids:
            offloads_deleted = (
                session.query(WorkflowNodeExecutionOffload)
                .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
                .delete(synchronize_session=False)
            )

        node_executions_deleted = 0
        if node_execution_ids:
            node_executions_deleted = (
                session.query(WorkflowNodeExecutionModel)
                .where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
                .delete(synchronize_session=False)
            )

        app_logs_deleted = (
            session.query(WorkflowAppLog)
            .where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids))
            .delete(synchronize_session=False)
        )

        pause_ids = session.scalars(
            select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(workflow_run_ids))
        ).all()
        pause_reasons_deleted = 0
        pauses_deleted = 0
        if pause_ids:
            pause_reasons_deleted = (
                session.query(WorkflowPauseReason)
                .where(WorkflowPauseReason.pause_id.in_(pause_ids))
                .delete(synchronize_session=False)
            )
            pauses_deleted = (
                session.query(WorkflowPause).where(WorkflowPause.id.in_(pause_ids)).delete(synchronize_session=False)
            )

        trigger_logs_deleted = (
            session.query(WorkflowTriggerLog)
            .where(WorkflowTriggerLog.workflow_run_id.in_(workflow_run_ids))
            .delete(synchronize_session=False)
        )
        runs_deleted = (
            session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
        )
        return {
            "runs": runs_deleted,
            "node_executions": node_executions_deleted,
            "offloads": offloads_deleted,
            "app_logs": app_logs_deleted,
            "trigger_logs": trigger_logs_deleted,
            "pauses": pauses_deleted,
            "pause_reasons": pause_reasons_deleted,
        }
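
Since the repository is now injected, the cleanup job can be exercised without touching the database. A minimal sketch (the fake below is hypothetical and implements only the two methods run() uses; the job's other constructor arguments are elided, as in the diff above):

class FakeWorkflowRunRepo:
    def get_runs_batch_for_cleanup(self, start_after, end_before, last_seen, batch_size):
        return []  # pretend the window is already empty

    def delete_runs_with_related(self, run_ids):
        return dict.fromkeys(
            ["runs", "node_executions", "offloads", "app_logs", "trigger_logs", "pauses", "pause_reasons"], 0
        )

# Passed as repo=FakeWorkflowRunRepo() when constructing WorkflowRunCleanup,
# so run() exercises the batching loop and reporting without issuing any SQL.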