diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index cb9c8921b6..d423081bf1 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -34,10 +34,12 @@ Example: ``` """ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol +from sqlalchemy.orm import Session + from core.workflow.entities.pause_reason import PauseReason from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -265,7 +267,11 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ ... - def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]: + def delete_runs_with_related( + self, + run_ids: Sequence[str], + delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: """ Delete workflow runs and their related records (node executions, offloads, app logs, trigger logs, pauses, pause reasons). diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 95de006d98..1fa89e8f64 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -21,7 +21,7 @@ Implementation Notes: import logging import uuid -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from decimal import Decimal from typing import Any, cast @@ -52,7 +52,6 @@ from models.workflow import ( ) from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity -from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from repositories.types import ( AverageInteractionStats, DailyRunsStats, @@ -361,7 +360,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): return session.scalars(stmt).all() - def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]: + def delete_runs_with_related( + self, + run_ids: Sequence[str], + delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, + ) -> dict[str, int]: if not run_ids: return { "runs": 0, @@ -411,7 +414,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids))) pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0 - trigger_logs_deleted = SQLAlchemyWorkflowTriggerLogRepository(session).delete_by_run_ids(run_ids) + trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0 runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))) runs_deleted = cast(CursorResult, runs_result).rowcount or 0 diff --git a/api/services/clear_free_plan_expired_workflow_run_logs.py b/api/services/clear_free_plan_expired_workflow_run_logs.py index 6734c0b020..f51beef923 100644 --- a/api/services/clear_free_plan_expired_workflow_run_logs.py +++ b/api/services/clear_free_plan_expired_workflow_run_logs.py @@ -1,14 +1,15 @@ import datetime import logging -from collections.abc import Iterable +from collections.abc import Iterable, Sequence import click -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.billing_service import BillingService logger = logging.getLogger(__name__) @@ -21,7 +22,6 @@ class WorkflowRunCleanup: batch_size: int, start_after: datetime.datetime | None = None, end_before: datetime.datetime | None = None, - repo: APIWorkflowRunRepository | None = None, ): if (start_after is None) ^ (end_before is None): raise ValueError("start_after and end_before must be both set or both omitted.") @@ -35,15 +35,12 @@ class WorkflowRunCleanup: self.batch_size = batch_size self.billing_cache: dict[str, CloudPlan | None] = {} - if repo: - self.repo = repo - else: - # Lazy import to avoid circular dependency during module import - from repositories.factory import DifyAPIRepositoryFactory + # Lazy import to avoid circular dependency during module import + from repositories.factory import DifyAPIRepositoryFactory - self.repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository( - sessionmaker(bind=db.engine, expire_on_commit=False) - ) + self.workflow_run_repo: APIWorkflowRunRepository = DifyAPIRepositoryFactory.create_api_workflow_run_repository( + sessionmaker(bind=db.engine, expire_on_commit=False) + ) def run(self) -> None: click.echo( @@ -60,7 +57,7 @@ class WorkflowRunCleanup: last_seen: tuple[datetime.datetime, str] | None = None while True: - run_rows = self.repo.get_runs_batch_by_time_range( + run_rows = self.workflow_run_repo.get_runs_batch_by_time_range( start_after=self.window_start, end_before=self.window_end, last_seen=last_seen, @@ -86,7 +83,10 @@ class WorkflowRunCleanup: continue try: - counts = self.repo.delete_runs_with_related(free_run_ids) + counts = self.workflow_run_repo.delete_runs_with_related( + free_run_ids, + delete_trigger_logs=self._delete_trigger_logs, + ) except Exception: logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) raise @@ -143,3 +143,7 @@ class WorkflowRunCleanup: self.billing_cache[tenant_id] = plan return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX} + + def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int: + trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) + return trigger_repo.delete_by_run_ids(run_ids) diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 90a3ef6985..8b81b45d67 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -232,11 +232,9 @@ class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): fake_trigger_repo = Mock() fake_trigger_repo.delete_by_run_ids.return_value = 3 - with patch( - "repositories.sqlalchemy_api_workflow_run_repository.SQLAlchemyWorkflowTriggerLogRepository", - return_value=fake_trigger_repo, - ): - counts = repository.delete_runs_with_related(["run-1"]) + counts = repository.delete_runs_with_related( + ["run-1"], delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids) + ) fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"]) assert counts["trigger_logs"] == 3 diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py index 415bb9b67d..9ca9aa9208 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -3,6 +3,7 @@ from typing import Any import pytest +import repositories.factory as repo_factory_module from services import clear_free_plan_expired_workflow_run_logs as cleanup_module from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup @@ -42,15 +43,24 @@ class FakeRepo: self.call_idx += 1 return batch - def delete_runs_with_related(self, run_ids: list[str]) -> dict[str, int]: + def delete_runs_with_related(self, run_ids: list[str], delete_trigger_logs=None) -> dict[str, int]: self.deleted.append(list(run_ids)) result = self.delete_result.copy() result["runs"] = len(run_ids) return result +def create_cleanup(monkeypatch: pytest.MonkeyPatch, repo: FakeRepo, **kwargs: Any) -> WorkflowRunCleanup: + monkeypatch.setattr( + repo_factory_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + classmethod(lambda _cls, session_maker: repo), + ) + return WorkflowRunCleanup(**kwargs) + + def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None: - cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([])) + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) @@ -66,7 +76,7 @@ def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) - def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None: - cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([])) + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX @@ -83,7 +93,7 @@ def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None: - cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([])) + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) monkeypatch.setattr( @@ -107,7 +117,7 @@ def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: ] ] ) - cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=repo) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX @@ -126,7 +136,7 @@ def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None: cutoff = datetime.datetime.now() repo = FakeRepo(batches=[[FakeRun("run-paid", "t_paid", cutoff)]]) - cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=repo) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True) monkeypatch.setattr( @@ -140,36 +150,38 @@ def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None assert repo.deleted == [] -def test_run_exits_on_empty_batch() -> None: - cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([])) +def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None: + cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10) cleanup.run() -def test_between_sets_window_bounds() -> None: +def test_between_sets_window_bounds(monkeypatch: pytest.MonkeyPatch) -> None: start_after = datetime.datetime(2024, 5, 1, 0, 0, 0) end_before = datetime.datetime(2024, 6, 1, 0, 0, 0) - cleanup = WorkflowRunCleanup( - days=30, batch_size=10, start_after=start_after, end_before=end_before, repo=FakeRepo([]) + cleanup = create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=start_after, end_before=end_before ) assert cleanup.window_start == start_after assert cleanup.window_end == end_before -def test_between_requires_both_boundaries() -> None: +def test_between_requires_both_boundaries(monkeypatch: pytest.MonkeyPatch) -> None: with pytest.raises(ValueError): - WorkflowRunCleanup( - days=30, batch_size=10, start_after=datetime.datetime.now(), end_before=None, repo=FakeRepo([]) + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=datetime.datetime.now(), end_before=None ) with pytest.raises(ValueError): - WorkflowRunCleanup( - days=30, batch_size=10, start_after=None, end_before=datetime.datetime.now(), repo=FakeRepo([]) + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=None, end_before=datetime.datetime.now() ) -def test_between_requires_end_after_start() -> None: +def test_between_requires_end_after_start(monkeypatch: pytest.MonkeyPatch) -> None: start_after = datetime.datetime(2024, 6, 1, 0, 0, 0) end_before = datetime.datetime(2024, 5, 1, 0, 0, 0) with pytest.raises(ValueError): - WorkflowRunCleanup(days=30, batch_size=10, start_after=start_after, end_before=end_before, repo=FakeRepo([])) + create_cleanup( + monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=start_after, end_before=end_before + )