refactor the repo and service

This commit is contained in:
hjlarry 2025-12-12 12:03:29 +08:00
parent da9a28b9e2
commit 231ecc1bfe
5 changed files with 65 additions and 42 deletions

View File

@@ -34,10 +34,12 @@ Example:
```
"""
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from datetime import datetime
from typing import Protocol
from sqlalchemy.orm import Session
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -265,7 +267,11 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
"""
...
def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
def delete_runs_with_related(
self,
run_ids: Sequence[str],
delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
) -> dict[str, int]:
"""
Delete workflow runs and their related records (node executions, offloads, app logs,
trigger logs, pauses, pause reasons).

View File

@@ -21,7 +21,7 @@ Implementation Notes:
import logging
import uuid
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from datetime import datetime
from decimal import Decimal
from typing import Any, cast
@@ -52,7 +52,6 @@ from models.workflow import (
)
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from repositories.types import (
AverageInteractionStats,
DailyRunsStats,
@@ -361,7 +360,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
return session.scalars(stmt).all()
def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
def delete_runs_with_related(
self,
run_ids: Sequence[str],
delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
) -> dict[str, int]:
if not run_ids:
return {
"runs": 0,
@@ -411,7 +414,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids)))
pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
trigger_logs_deleted = SQLAlchemyWorkflowTriggerLogRepository(session).delete_by_run_ids(run_ids)
trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)))
runs_deleted = cast(CursorResult, runs_result).rowcount or 0

View File

@@ -1,14 +1,15 @@
import datetime
import logging
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
import click
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.billing_service import BillingService
logger = logging.getLogger(__name__)
@@ -21,7 +22,6 @@ class WorkflowRunCleanup:
batch_size: int,
start_after: datetime.datetime | None = None,
end_before: datetime.datetime | None = None,
repo: APIWorkflowRunRepository | None = None,
):
if (start_after is None) ^ (end_before is None):
raise ValueError("start_after and end_before must be both set or both omitted.")
@@ -35,15 +35,12 @@
self.batch_size = batch_size
self.billing_cache: dict[str, CloudPlan | None] = {}
if repo:
self.repo = repo
else:
# Lazy import to avoid circular dependency during module import
from repositories.factory import DifyAPIRepositoryFactory
# Lazy import to avoid circular dependency during module import
from repositories.factory import DifyAPIRepositoryFactory
self.repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
sessionmaker(bind=db.engine, expire_on_commit=False)
)
self.workflow_run_repo: APIWorkflowRunRepository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
sessionmaker(bind=db.engine, expire_on_commit=False)
)
def run(self) -> None:
click.echo(
@@ -60,7 +57,7 @@
last_seen: tuple[datetime.datetime, str] | None = None
while True:
run_rows = self.repo.get_runs_batch_by_time_range(
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
start_after=self.window_start,
end_before=self.window_end,
last_seen=last_seen,
@@ -86,7 +83,10 @@
continue
try:
counts = self.repo.delete_runs_with_related(free_run_ids)
counts = self.workflow_run_repo.delete_runs_with_related(
free_run_ids,
delete_trigger_logs=self._delete_trigger_logs,
)
except Exception:
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
raise
@@ -143,3 +143,7 @@
self.billing_cache[tenant_id] = plan
return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX}
def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
return trigger_repo.delete_by_run_ids(run_ids)

View File

@@ -232,11 +232,9 @@ class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
fake_trigger_repo = Mock()
fake_trigger_repo.delete_by_run_ids.return_value = 3
with patch(
"repositories.sqlalchemy_api_workflow_run_repository.SQLAlchemyWorkflowTriggerLogRepository",
return_value=fake_trigger_repo,
):
counts = repository.delete_runs_with_related(["run-1"])
counts = repository.delete_runs_with_related(
["run-1"], delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids)
)
fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"])
assert counts["trigger_logs"] == 3

View File

@@ -3,6 +3,7 @@ from typing import Any
import pytest
import repositories.factory as repo_factory_module
from services import clear_free_plan_expired_workflow_run_logs as cleanup_module
from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
@@ -42,15 +43,24 @@ class FakeRepo:
self.call_idx += 1
return batch
def delete_runs_with_related(self, run_ids: list[str]) -> dict[str, int]:
def delete_runs_with_related(self, run_ids: list[str], delete_trigger_logs=None) -> dict[str, int]:
self.deleted.append(list(run_ids))
result = self.delete_result.copy()
result["runs"] = len(run_ids)
return result
def create_cleanup(monkeypatch: pytest.MonkeyPatch, repo: FakeRepo, **kwargs: Any) -> WorkflowRunCleanup:
monkeypatch.setattr(
repo_factory_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
classmethod(lambda _cls, session_maker: repo),
)
return WorkflowRunCleanup(**kwargs)
def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([]))
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
@@ -66,7 +76,7 @@ def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -
def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None:
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([]))
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX
@@ -83,7 +93,7 @@ def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None
def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None:
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([]))
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
monkeypatch.setattr(
@@ -107,7 +117,7 @@ def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
]
]
)
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=repo)
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX
@@ -126,7 +136,7 @@ def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
cutoff = datetime.datetime.now()
repo = FakeRepo(batches=[[FakeRun("run-paid", "t_paid", cutoff)]])
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=repo)
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
monkeypatch.setattr(
@@ -140,36 +150,38 @@ def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None
assert repo.deleted == []
def test_run_exits_on_empty_batch() -> None:
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([]))
def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None:
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
cleanup.run()
def test_between_sets_window_bounds() -> None:
def test_between_sets_window_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
start_after = datetime.datetime(2024, 5, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 6, 1, 0, 0, 0)
cleanup = WorkflowRunCleanup(
days=30, batch_size=10, start_after=start_after, end_before=end_before, repo=FakeRepo([])
cleanup = create_cleanup(
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=start_after, end_before=end_before
)
assert cleanup.window_start == start_after
assert cleanup.window_end == end_before
def test_between_requires_both_boundaries() -> None:
def test_between_requires_both_boundaries(monkeypatch: pytest.MonkeyPatch) -> None:
with pytest.raises(ValueError):
WorkflowRunCleanup(
days=30, batch_size=10, start_after=datetime.datetime.now(), end_before=None, repo=FakeRepo([])
create_cleanup(
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=datetime.datetime.now(), end_before=None
)
with pytest.raises(ValueError):
WorkflowRunCleanup(
days=30, batch_size=10, start_after=None, end_before=datetime.datetime.now(), repo=FakeRepo([])
create_cleanup(
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=None, end_before=datetime.datetime.now()
)
def test_between_requires_end_after_start() -> None:
def test_between_requires_end_after_start(monkeypatch: pytest.MonkeyPatch) -> None:
start_after = datetime.datetime(2024, 6, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 5, 1, 0, 0, 0)
with pytest.raises(ValueError):
WorkflowRunCleanup(days=30, batch_size=10, start_after=start_after, end_before=end_before, repo=FakeRepo([]))
create_cleanup(
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=start_after, end_before=end_before
)