diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index d423081bf1..22c3ea7130 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -269,7 +269,8 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): def delete_runs_with_related( self, - run_ids: Sequence[str], + runs: Sequence[WorkflowRun], + delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, ) -> dict[str, int]: """ diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 7e2173acdd..9d0a42f518 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -7,13 +7,13 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. 
from collections.abc import Sequence from datetime import datetime -from typing import cast +from typing import TypedDict, cast -from sqlalchemy import asc, delete, desc, select +from sqlalchemy import asc, delete, desc, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from models.workflow import WorkflowNodeExecutionModel +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository @@ -290,3 +290,58 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut result = cast(CursorResult, session.execute(stmt)) session.commit() return result.rowcount + + class _RunContext(TypedDict): + run_id: str + tenant_id: str + app_id: str + workflow_id: str + triggered_from: str + + @staticmethod + def delete_by_runs(session: Session, runs: Sequence[_RunContext]) -> tuple[int, int]: + """ + Delete node executions (and offloads) for the given workflow runs using indexed columns. + + Uses the composite index on (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id) + by filtering on those columns with tuple IN. 
+ """ + if not runs: + return 0, 0 + + tuple_values = [ + (run["tenant_id"], run["app_id"], run["workflow_id"], run["triggered_from"], run["run_id"]) for run in runs + ] + + node_execution_ids = session.scalars( + select(WorkflowNodeExecutionModel.id).where( + tuple_( + WorkflowNodeExecutionModel.tenant_id, + WorkflowNodeExecutionModel.app_id, + WorkflowNodeExecutionModel.workflow_id, + WorkflowNodeExecutionModel.triggered_from, + WorkflowNodeExecutionModel.workflow_run_id, + ).in_(tuple_values) + ) + ).all() + + if not node_execution_ids: + return 0, 0 + + offloads_deleted = cast( + CursorResult, + session.execute( + delete(WorkflowNodeExecutionOffload).where( + WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) + ) + ), + ).rowcount or 0 + + node_executions_deleted = cast( + CursorResult, + session.execute( + delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids)) + ), + ).rowcount or 0 + + return node_executions_deleted, offloads_deleted diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 1fa89e8f64..f081124a80 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -42,8 +42,6 @@ from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom from models.workflow import ( WorkflowAppLog, - WorkflowNodeExecutionModel, - WorkflowNodeExecutionOffload, WorkflowPauseReason, WorkflowRun, ) @@ -362,10 +360,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): def delete_runs_with_related( self, - run_ids: Sequence[str], + runs: Sequence[WorkflowRun], + delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, ) -> dict[str, int]: - if not run_ids: + if not runs: return { "runs": 0, 
"node_executions": 0, @@ -377,25 +376,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): } with self._session_maker() as session: - node_execution_ids = session.scalars( - select(WorkflowNodeExecutionModel.id).where(WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)) - ).all() - - offloads_deleted = 0 - if node_execution_ids: - offloads_result = session.execute( - delete(WorkflowNodeExecutionOffload).where( - WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids) - ) - ) - offloads_deleted = cast(CursorResult, offloads_result).rowcount or 0 - - node_executions_deleted = 0 - if node_execution_ids: - node_executions_result = session.execute( - delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids)) - ) - node_executions_deleted = cast(CursorResult, node_executions_result).rowcount or 0 + run_ids = [run.id for run in runs] + if delete_node_executions: + node_executions_deleted, offloads_deleted = delete_node_executions(session, runs) + else: + node_executions_deleted, offloads_deleted = 0, 0 app_logs_result = session.execute(delete(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(run_ids))) app_logs_deleted = cast(CursorResult, app_logs_result).rowcount or 0 diff --git a/api/services/clear_free_plan_expired_workflow_run_logs.py b/api/services/clear_free_plan_expired_workflow_run_logs.py index f51beef923..cd5f0d208d 100644 --- a/api/services/clear_free_plan_expired_workflow_run_logs.py +++ b/api/services/clear_free_plan_expired_workflow_run_logs.py @@ -9,6 +9,9 @@ from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_api_workflow_node_execution_repository import ( + DifyAPISQLAlchemyWorkflowNodeExecutionRepository, +) from repositories.sqlalchemy_workflow_trigger_log_repository import 
def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
    """Delete trigger logs for ``run_ids`` on ``session`` and return what ``delete_by_run_ids`` reports."""
    return SQLAlchemyWorkflowTriggerLogRepository(session).delete_by_run_ids(run_ids)


def _delete_node_executions(self, session: Session, runs: Sequence[object]) -> tuple[int, int]:
    """
    Delete node executions (and their offloads) for ``runs`` on the given session.

    Each run is reduced to the plain mapping that
    ``DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs`` consumes, and the
    repository's result pair is returned unchanged.

    Args:
        session: Session the deletion statements are executed on.
        runs: Objects exposing ``id``, ``tenant_id``, ``app_id``, ``workflow_id`` and
            ``triggered_from``.

    Returns:
        The tuple produced by ``delete_by_runs``.
    """
    # NOTE(review): the annotation is Sequence[object] although attributes are read from
    # each element -- a narrower element type is presumably intended; confirm.

    def context_for(run):
        # NOTE(review): triggered_from is forwarded from the run as-is; verify it is the
        # same value node executions store in their own triggered_from column.
        return {
            "run_id": run.id,
            "tenant_id": run.tenant_id,
            "app_id": run.app_id,
            "workflow_id": run.workflow_id,
            "triggered_from": run.triggered_from,
        }

    contexts = [context_for(run) for run in runs]
    return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs(session, contexts)
class FakeRun:
    """Lightweight stand-in exposing the workflow-run attributes the cleanup tests read."""

    def __init__(
        self,
        run_id: str,
        tenant_id: str,
        created_at: datetime.datetime,
        app_id: str = "app-1",
        workflow_id: str = "wf-1",
        triggered_from: str = "workflow-run",
    ) -> None:
        """Store every argument as a same-purpose attribute; ``run_id`` is exposed as ``id``."""
        # Identity and ordering fields.
        self.id, self.created_at = run_id, created_at
        # Ownership / origin fields; the optional ones fall back to fixed placeholder values.
        self.tenant_id, self.app_id = tenant_id, app_id
        self.workflow_id, self.triggered_from = workflow_id, triggered_from
delete_node_executions=None, + delete_trigger_logs=None) -> dict[str, int]: + self.deleted.append([run.id for run in runs]) result = self.delete_result.copy() - result["runs"] = len(run_ids) + result["runs"] = len(runs) return result @@ -56,6 +69,7 @@ def create_cleanup(monkeypatch: pytest.MonkeyPatch, repo: FakeRepo, **kwargs: An "create_api_workflow_run_repository", classmethod(lambda _cls, session_maker: repo), ) + kwargs.pop("repo", None) return WorkflowRunCleanup(**kwargs)