mirror of https://github.com/langgenius/dify.git
refactor to repo layer
This commit is contained in:
parent 22443df772
commit 46d824b17f
@@ -253,6 +253,25 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
         """
         ...
 
+    def get_runs_batch_for_cleanup(
+        self,
+        start_after: datetime | None,
+        end_before: datetime,
+        last_seen: tuple[datetime, str] | None,
+        batch_size: int,
+    ) -> Sequence[WorkflowRun]:
+        """
+        Fetch a batch of workflow runs within a time window using keyset pagination for cleanup.
+        """
+        ...
+
+    def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
+        """
+        Delete workflow runs and their related records (node executions, offloads, app logs,
+        trigger logs, pauses, pause reasons).
+        """
+        ...
+
     def create_workflow_pause(
         self,
         workflow_run_id: str,
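Note: the two new protocol methods are meant to be driven as a pair: fetch a keyset-paginated batch, remember the (created_at, id) of the last row as the cursor, delete, and repeat until the batch comes back empty. A minimal caller sketch under that assumption (the function name purge_runs_before and its arguments are illustrative, not part of this commit):

from datetime import datetime

def purge_runs_before(repo, cutoff: datetime, batch_size: int = 1000) -> int:
    # Illustrative driver only; `repo` is any APIWorkflowRunRepository implementation.
    deleted = 0
    last_seen: tuple[datetime, str] | None = None
    while True:
        batch = repo.get_runs_batch_for_cleanup(
            start_after=None,  # no lower bound on the cleanup window
            end_before=cutoff,
            last_seen=last_seen,
            batch_size=batch_size,
        )
        if not batch:
            break
        # Advance the cursor from the last row so the next query resumes after it,
        # even if the caller ends up skipping some rows in this batch.
        last_seen = (batch[-1].created_at, batch[-1].id)
        counts = repo.delete_runs_with_related([run.id for run in batch])
        deleted += counts["runs"]
    return deleted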
@@ -40,8 +40,17 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.time_parser import get_time_threshold
 from libs.uuid_utils import uuidv7
 from models.enums import WorkflowRunTriggeredFrom
-from models.workflow import WorkflowPause as WorkflowPauseModel
-from models.workflow import WorkflowPauseReason, WorkflowRun
+from models.trigger import WorkflowTriggerLog
+from models.workflow import (
+    WorkflowAppLog,
+    WorkflowNodeExecutionModel,
+    WorkflowNodeExecutionOffload,
+    WorkflowPauseReason,
+    WorkflowRun,
+)
+from models.workflow import (
+    WorkflowPause as WorkflowPauseModel,
+)
 from repositories.api_workflow_run_repository import APIWorkflowRunRepository
 from repositories.entities.workflow_pause import WorkflowPauseEntity
 from repositories.types import (
@@ -314,6 +323,115 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
         logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id)
         return total_deleted
 
+    def get_runs_batch_for_cleanup(
+        self,
+        start_after: datetime | None,
+        end_before: datetime,
+        last_seen: tuple[datetime, str] | None,
+        batch_size: int,
+    ) -> Sequence[WorkflowRun]:
+        with self._session_maker() as session:
+            stmt = (
+                select(WorkflowRun)
+                .where(WorkflowRun.created_at < end_before)
+                .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
+                .limit(batch_size)
+            )
+
+            if start_after:
+                stmt = stmt.where(WorkflowRun.created_at >= start_after)
+
+            if last_seen:
+                stmt = stmt.where(
+                    or_(
+                        WorkflowRun.created_at > last_seen[0],
+                        and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
+                    )
+                )
+
+            return session.scalars(stmt).all()
+
+    def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
+        if not run_ids:
+            return {
+                "runs": 0,
+                "node_executions": 0,
+                "offloads": 0,
+                "app_logs": 0,
+                "trigger_logs": 0,
+                "pauses": 0,
+                "pause_reasons": 0,
+            }
+
+        with self._session_maker() as session:
+            node_execution_ids = session.scalars(
+                select(WorkflowNodeExecutionModel.id).where(
+                    WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
+                )
+            ).all()
+
+            offloads_deleted = 0
+            if node_execution_ids:
+                offloads_deleted = (
+                    session.query(WorkflowNodeExecutionOffload)
+                    .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
+                    .delete(synchronize_session=False)
+                )
+
+            node_executions_deleted = 0
+            if node_execution_ids:
+                node_executions_deleted = (
+                    session.query(WorkflowNodeExecutionModel)
+                    .where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
+                    .delete(synchronize_session=False)
+                )
+
+            app_logs_deleted = (
+                session.query(WorkflowAppLog)
+                .where(WorkflowAppLog.workflow_run_id.in_(run_ids))
+                .delete(synchronize_session=False)
+            )
+
+            pause_ids = session.scalars(
+                select(WorkflowPauseModel.id).where(WorkflowPauseModel.workflow_run_id.in_(run_ids))
+            ).all()
+            pause_reasons_deleted = 0
+            pauses_deleted = 0
+
+            if pause_ids:
+                pause_reasons_deleted = (
+                    session.query(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids)).delete(
+                        synchronize_session=False
+                    )
+                )
+                pauses_deleted = (
+                    session.query(WorkflowPauseModel)
+                    .where(WorkflowPauseModel.id.in_(pause_ids))
+                    .delete(synchronize_session=False)
+                )
+
+            trigger_logs_deleted = (
+                session.query(WorkflowTriggerLog)
+                .where(WorkflowTriggerLog.workflow_run_id.in_(run_ids))
+                .delete(synchronize_session=False)
+            )
+
+            runs_deleted = (
+                session.query(WorkflowRun).where(WorkflowRun.id.in_(run_ids)).delete(synchronize_session=False)
+            )
+
+            session.commit()
+
+            return {
+                "runs": runs_deleted,
+                "node_executions": node_executions_deleted,
+                "offloads": offloads_deleted,
+                "app_logs": app_logs_deleted,
+                "trigger_logs": trigger_logs_deleted,
+                "pauses": pauses_deleted,
+                "pause_reasons": pause_reasons_deleted,
+            }
+
     def create_workflow_pause(
         self,
         workflow_run_id: str,
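Note: the or_/and_ pair in get_runs_batch_for_cleanup spells out a row-value comparison by hand, which keeps the predicate portable across backends. On PostgreSQL the same keyset condition could be written with sqlalchemy.tuple_; the sketch below is only an equivalent illustration (the helper name keyset_page_stmt is invented here), not what the commit ships:

from datetime import datetime

from sqlalchemy import select, tuple_

from models.workflow import WorkflowRun


def keyset_page_stmt(end_before: datetime, last_seen: tuple[datetime, str], batch_size: int):
    # Equivalent keyset predicate as a row-value comparison: (created_at, id) > (:ts, :id).
    # Backend support for row-value comparison varies (PostgreSQL accepts it), which is
    # presumably why the commit keeps the explicit or_/and_ form; both rely on the same
    # (created_at, id) ordering and index shape.
    return (
        select(WorkflowRun)
        .where(WorkflowRun.created_at < end_before)
        .where(tuple_(WorkflowRun.created_at, WorkflowRun.id) > (last_seen[0], last_seen[1]))
        .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
        .limit(batch_size)
    )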
@@ -30,7 +30,7 @@ class BillingService:
         return billing_info
 
     @classmethod
-    def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, dict]:
+    def get_info_bulk(cls, tenant_ids: Sequence[str]) -> dict[str, str]:
         """
         Bulk billing info fetch via billing API.
 
@@ -39,17 +39,17 @@ class BillingService:
         Returns:
             Mapping of tenant_id -> plan
         """
-        results: dict[str, dict] = {}
+        results: dict[str, str] = {}
 
         chunk_size = 200
         for i in range(0, len(tenant_ids), chunk_size):
             chunk = tenant_ids[i : i + chunk_size]
             try:
                 resp = cls._send_request("POST", "/subscription/plan/batch", json={"tenant_ids": chunk})
-                data = resp.get("data", {}) if isinstance(resp, dict) else {}
-                if data:
-                    results.update(data)
+                data = resp.get("data", {})
+                for tenant_id, plan in data.items():
+                    if isinstance(plan, str):
+                        results[tenant_id] = plan
             except Exception:
                 logger.exception("Failed to fetch billing info batch for tenants: %s", chunk)
                 continue
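Note: get_info_bulk now returns a flat tenant_id -> plan mapping of strings, fetched 200 tenants per request, and silently drops tenants whose plan is missing or not a string. A small usage illustration (the tenant IDs are invented, and the exact plan strings depend on the billing API):

from services.billing_service import BillingService

plans = BillingService.get_info_bulk(["tenant-a", "tenant-b", "tenant-c"])
# e.g. {"tenant-a": "sandbox", "tenant-b": "professional"}; tenants the billing API
# did not report simply have no key, so look them up with .get().
for tenant_id in ["tenant-a", "tenant-b", "tenant-c"]:
    print(tenant_id, plans.get(tenant_id, "<unknown>"))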
@@ -1,30 +1,19 @@
 import datetime
 import logging
-from collections.abc import Iterable, Sequence
-from dataclasses import dataclass
+from collections.abc import Iterable
 
 import click
-import sqlalchemy as sa
-from sqlalchemy import select
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
 
 from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
-from models import WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
-from models.trigger import WorkflowTriggerLog
-from models.workflow import WorkflowNodeExecutionOffload, WorkflowPause, WorkflowPauseReason
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory
 from services.billing_service import BillingService
 
 logger = logging.getLogger(__name__)
 
 
-@dataclass(frozen=True)
-class WorkflowRunRow:
-    id: str
-    tenant_id: str
-    created_at: datetime.datetime
-
-
 class WorkflowRunCleanup:
     def __init__(
         self,
@@ -32,6 +21,7 @@ class WorkflowRunCleanup:
         batch_size: int,
         start_after: datetime.datetime | None = None,
         end_before: datetime.datetime | None = None,
+        repo: APIWorkflowRunRepository | None = None,
     ):
         if (start_after is None) ^ (end_before is None):
             raise ValueError("start_after and end_before must be both set or both omitted.")
@@ -45,6 +35,9 @@ class WorkflowRunCleanup:
 
         self.batch_size = batch_size
         self.billing_cache: dict[str, CloudPlan | None] = {}
+        self.repo = repo or DifyAPIRepositoryFactory.create_api_workflow_run_repository(
+            sessionmaker(bind=db.engine, expire_on_commit=False)
+        )
 
     def run(self) -> None:
         click.echo(
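Note: because __init__ now accepts an optional repo, the cleanup job can be exercised without a live database by injecting a stand-in. A hedged test sketch; the _StubRepo class and the constructor call in the trailing comment are invented for illustration and may not match the full WorkflowRunCleanup signature:

class _StubRepo:
    # Minimal stand-in for APIWorkflowRunRepository, for illustration only.
    def __init__(self, batches):
        self._batches = list(batches)
        self.deleted_run_ids: list[str] = []

    def get_runs_batch_for_cleanup(self, start_after, end_before, last_seen, batch_size):
        # Hand back pre-canned batches until they run out, then an empty list to stop the loop.
        return self._batches.pop(0) if self._batches else []

    def delete_runs_with_related(self, run_ids):
        self.deleted_run_ids.extend(run_ids)
        return {
            "runs": len(run_ids), "node_executions": 0, "offloads": 0,
            "app_logs": 0, "trigger_logs": 0, "pauses": 0, "pause_reasons": 0,
        }

# cleanup = WorkflowRunCleanup(batch_size=100, repo=_StubRepo(batches=[...]))
# cleanup.run()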
@@ -61,46 +54,48 @@ class WorkflowRunCleanup:
         last_seen: tuple[datetime.datetime, str] | None = None
 
         while True:
-            with Session(db.engine) as session:
-                run_rows = self._load_batch(session, last_seen)
-                if not run_rows:
-                    break
-
-                batch_index += 1
-                last_seen = (run_rows[-1].created_at, run_rows[-1].id)
-                tenant_ids = {row.tenant_id for row in run_rows}
-                free_tenants = self._filter_free_tenants(tenant_ids)
-                free_run_ids = [row.id for row in run_rows if row.tenant_id in free_tenants]
-                paid_or_skipped = len(run_rows) - len(free_run_ids)
-
-                if not free_run_ids:
-                    click.echo(
-                        click.style(
-                            f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid)",
-                            fg="yellow",
-                        )
-                    )
-                    continue
-
-                try:
-                    counts = self._delete_runs(session, free_run_ids)
-                    session.commit()
-                except Exception:
-                    session.rollback()
-                    logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
-                    raise
-
-                total_runs_deleted += counts["runs"]
-                click.echo(
-                    click.style(
-                        f"[batch #{batch_index}] deleted runs: {counts['runs']} "
-                        f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
-                        f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
-                        f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
-                        f"skipped {paid_or_skipped} paid/unknown",
-                        fg="green",
-                    )
-                )
+            run_rows = self.repo.get_runs_batch_for_cleanup(
+                start_after=self.window_start,
+                end_before=self.window_end,
+                last_seen=last_seen,
+                batch_size=self.batch_size,
+            )
+            if not run_rows:
+                break
+
+            batch_index += 1
+            last_seen = (run_rows[-1].created_at, run_rows[-1].id)
+            tenant_ids = {row.tenant_id for row in run_rows}
+            free_tenants = self._filter_free_tenants(tenant_ids)
+            free_run_ids = [row.id for row in run_rows if row.tenant_id in free_tenants]
+            paid_or_skipped = len(run_rows) - len(free_run_ids)
+
+            if not free_run_ids:
+                click.echo(
+                    click.style(
+                        f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)",
+                        fg="yellow",
+                    )
+                )
+                continue
+
+            try:
+                counts = self.repo.delete_runs_with_related(free_run_ids)
+            except Exception:
+                logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
+                raise
+
+            total_runs_deleted += counts["runs"]
+            click.echo(
+                click.style(
+                    f"[batch #{batch_index}] deleted runs: {counts['runs']} "
+                    f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
+                    f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
+                    f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
+                    f"skipped {paid_or_skipped} paid/unknown",
+                    fg="green",
+                )
+            )
 
         if self.window_start:
             summary_message = (
@@ -114,28 +109,6 @@ class WorkflowRunCleanup:
 
         click.echo(click.style(summary_message, fg="white"))
 
-    def _load_batch(self, session: Session, last_seen: tuple[datetime.datetime, str] | None) -> list[WorkflowRunRow]:
-        stmt = (
-            select(WorkflowRun.id, WorkflowRun.tenant_id, WorkflowRun.created_at)
-            .where(WorkflowRun.created_at < self.window_end)
-            .order_by(WorkflowRun.created_at.asc(), WorkflowRun.id.asc())
-            .limit(self.batch_size)
-        )
-
-        if self.window_start:
-            stmt = stmt.where(WorkflowRun.created_at >= self.window_start)
-
-        if last_seen:
-            stmt = stmt.where(
-                sa.or_(
-                    WorkflowRun.created_at > last_seen[0],
-                    sa.and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
-                )
-            )
-
-        rows = session.execute(stmt).all()
-        return [WorkflowRunRow(id=row.id, tenant_id=row.tenant_id, created_at=row.created_at) for row in rows]
-
     def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
         tenant_id_list = list(tenant_ids)
         uncached_tenants = [tenant_id for tenant_id in tenant_id_list if tenant_id not in self.billing_cache]
@@ -161,68 +134,3 @@ class WorkflowRunCleanup:
             self.billing_cache[tenant_id] = plan
 
         return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX}
-
-    def _delete_runs(self, session: Session, workflow_run_ids: Sequence[str]) -> dict[str, int]:
-        node_execution_ids = session.scalars(
-            select(WorkflowNodeExecutionModel.id).where(
-                WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids)
-            )
-        ).all()
-
-        offloads_deleted = 0
-        if node_execution_ids:
-            offloads_deleted = (
-                session.query(WorkflowNodeExecutionOffload)
-                .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
-                .delete(synchronize_session=False)
-            )
-
-        node_executions_deleted = 0
-        if node_execution_ids:
-            node_executions_deleted = (
-                session.query(WorkflowNodeExecutionModel)
-                .where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
-                .delete(synchronize_session=False)
-            )
-
-        app_logs_deleted = (
-            session.query(WorkflowAppLog)
-            .where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids))
-            .delete(synchronize_session=False)
-        )
-
-        pause_ids = session.scalars(
-            select(WorkflowPause.id).where(WorkflowPause.workflow_run_id.in_(workflow_run_ids))
-        ).all()
-        pause_reasons_deleted = 0
-        pauses_deleted = 0
-
-        if pause_ids:
-            pause_reasons_deleted = (
-                session.query(WorkflowPauseReason)
-                .where(WorkflowPauseReason.pause_id.in_(pause_ids))
-                .delete(synchronize_session=False)
-            )
-            pauses_deleted = (
-                session.query(WorkflowPause).where(WorkflowPause.id.in_(pause_ids)).delete(synchronize_session=False)
-            )
-
-        trigger_logs_deleted = (
-            session.query(WorkflowTriggerLog)
-            .where(WorkflowTriggerLog.workflow_run_id.in_(workflow_run_ids))
-            .delete(synchronize_session=False)
-        )
-
-        runs_deleted = (
-            session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False)
-        )
-
-        return {
-            "runs": runs_deleted,
-            "node_executions": node_executions_deleted,
-            "offloads": offloads_deleted,
-            "app_logs": app_logs_deleted,
-            "trigger_logs": trigger_logs_deleted,
-            "pauses": pauses_deleted,
-            "pause_reasons": pause_reasons_deleted,
-        }