mirror of https://github.com/langgenius/dify.git
refactor the repo and service
This commit is contained in:
parent
da9a28b9e2
commit
231ecc1bfe
|
|
@ -34,10 +34,12 @@ Example:
|
|||
```
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
|
|
@ -265,7 +267,11 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
|
||||
def delete_runs_with_related(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Delete workflow runs and their related records (node executions, offloads, app logs,
|
||||
trigger logs, pauses, pause reasons).
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ Implementation Notes:
|
|||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
|
|
@ -52,7 +52,6 @@ from models.workflow import (
|
|||
)
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from repositories.types import (
|
||||
AverageInteractionStats,
|
||||
DailyRunsStats,
|
||||
|
|
@ -361,7 +360,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def delete_runs_with_related(self, run_ids: Sequence[str]) -> dict[str, int]:
|
||||
def delete_runs_with_related(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if not run_ids:
|
||||
return {
|
||||
"runs": 0,
|
||||
|
|
@ -411,7 +414,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||
pauses_result = session.execute(delete(WorkflowPauseModel).where(WorkflowPauseModel.id.in_(pause_ids)))
|
||||
pauses_deleted = cast(CursorResult, pauses_result).rowcount or 0
|
||||
|
||||
trigger_logs_deleted = SQLAlchemyWorkflowTriggerLogRepository(session).delete_by_run_ids(run_ids)
|
||||
trigger_logs_deleted = delete_trigger_logs(session, run_ids) if delete_trigger_logs else 0
|
||||
|
||||
runs_result = session.execute(delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)))
|
||||
runs_deleted = cast(CursorResult, runs_result).rowcount or 0
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
import datetime
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.billing_service import BillingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -21,7 +22,6 @@ class WorkflowRunCleanup:
|
|||
batch_size: int,
|
||||
start_after: datetime.datetime | None = None,
|
||||
end_before: datetime.datetime | None = None,
|
||||
repo: APIWorkflowRunRepository | None = None,
|
||||
):
|
||||
if (start_after is None) ^ (end_before is None):
|
||||
raise ValueError("start_after and end_before must be both set or both omitted.")
|
||||
|
|
@ -35,15 +35,12 @@ class WorkflowRunCleanup:
|
|||
|
||||
self.batch_size = batch_size
|
||||
self.billing_cache: dict[str, CloudPlan | None] = {}
|
||||
if repo:
|
||||
self.repo = repo
|
||||
else:
|
||||
# Lazy import to avoid circular dependency during module import
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
# Lazy import to avoid circular dependency during module import
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
self.repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
|
||||
sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
)
|
||||
self.workflow_run_repo: APIWorkflowRunRepository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
|
||||
sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
click.echo(
|
||||
|
|
@ -60,7 +57,7 @@ class WorkflowRunCleanup:
|
|||
last_seen: tuple[datetime.datetime, str] | None = None
|
||||
|
||||
while True:
|
||||
run_rows = self.repo.get_runs_batch_by_time_range(
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_after=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
|
|
@ -86,7 +83,10 @@ class WorkflowRunCleanup:
|
|||
continue
|
||||
|
||||
try:
|
||||
counts = self.repo.delete_runs_with_related(free_run_ids)
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_run_ids,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
|
@ -143,3 +143,7 @@ class WorkflowRunCleanup:
|
|||
self.billing_cache[tenant_id] = plan
|
||||
|
||||
return {tenant_id for tenant_id in tenant_id_list if self.billing_cache.get(tenant_id) == CloudPlan.SANDBOX}
|
||||
|
||||
def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
|
||||
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
return trigger_repo.delete_by_run_ids(run_ids)
|
||||
|
|
|
|||
|
|
@ -232,11 +232,9 @@ class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
|||
fake_trigger_repo = Mock()
|
||||
fake_trigger_repo.delete_by_run_ids.return_value = 3
|
||||
|
||||
with patch(
|
||||
"repositories.sqlalchemy_api_workflow_run_repository.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=fake_trigger_repo,
|
||||
):
|
||||
counts = repository.delete_runs_with_related(["run-1"])
|
||||
counts = repository.delete_runs_with_related(
|
||||
["run-1"], delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids)
|
||||
)
|
||||
|
||||
fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"])
|
||||
assert counts["trigger_logs"] == 3
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any
|
|||
|
||||
import pytest
|
||||
|
||||
import repositories.factory as repo_factory_module
|
||||
from services import clear_free_plan_expired_workflow_run_logs as cleanup_module
|
||||
from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||
|
||||
|
|
@ -42,15 +43,24 @@ class FakeRepo:
|
|||
self.call_idx += 1
|
||||
return batch
|
||||
|
||||
def delete_runs_with_related(self, run_ids: list[str]) -> dict[str, int]:
|
||||
def delete_runs_with_related(self, run_ids: list[str], delete_trigger_logs=None) -> dict[str, int]:
|
||||
self.deleted.append(list(run_ids))
|
||||
result = self.delete_result.copy()
|
||||
result["runs"] = len(run_ids)
|
||||
return result
|
||||
|
||||
|
||||
def create_cleanup(monkeypatch: pytest.MonkeyPatch, repo: FakeRepo, **kwargs: Any) -> WorkflowRunCleanup:
|
||||
monkeypatch.setattr(
|
||||
repo_factory_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
classmethod(lambda _cls, session_maker: repo),
|
||||
)
|
||||
return WorkflowRunCleanup(**kwargs)
|
||||
|
||||
|
||||
def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([]))
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
|
|
@ -66,7 +76,7 @@ def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -
|
|||
|
||||
|
||||
def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([]))
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX
|
||||
|
|
@ -83,7 +93,7 @@ def test_filter_free_tenants_bulk_mixed(monkeypatch: pytest.MonkeyPatch) -> None
|
|||
|
||||
|
||||
def test_filter_free_tenants_bulk_failure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([]))
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -107,7 +117,7 @@ def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
]
|
||||
]
|
||||
)
|
||||
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=repo)
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
cleanup.billing_cache["t_free"] = cleanup_module.CloudPlan.SANDBOX
|
||||
|
|
@ -126,7 +136,7 @@ def test_run_deletes_only_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(batches=[[FakeRun("run-paid", "t_paid", cutoff)]])
|
||||
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=repo)
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", True)
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -140,36 +150,38 @@ def test_run_skips_when_no_free_tenants(monkeypatch: pytest.MonkeyPatch) -> None
|
|||
assert repo.deleted == []
|
||||
|
||||
|
||||
def test_run_exits_on_empty_batch() -> None:
|
||||
cleanup = WorkflowRunCleanup(days=30, batch_size=10, repo=FakeRepo([]))
|
||||
def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup = create_cleanup(monkeypatch, repo=FakeRepo([]), days=30, batch_size=10)
|
||||
|
||||
cleanup.run()
|
||||
|
||||
|
||||
def test_between_sets_window_bounds() -> None:
|
||||
def test_between_sets_window_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
start_after = datetime.datetime(2024, 5, 1, 0, 0, 0)
|
||||
end_before = datetime.datetime(2024, 6, 1, 0, 0, 0)
|
||||
cleanup = WorkflowRunCleanup(
|
||||
days=30, batch_size=10, start_after=start_after, end_before=end_before, repo=FakeRepo([])
|
||||
cleanup = create_cleanup(
|
||||
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=start_after, end_before=end_before
|
||||
)
|
||||
|
||||
assert cleanup.window_start == start_after
|
||||
assert cleanup.window_end == end_before
|
||||
|
||||
|
||||
def test_between_requires_both_boundaries() -> None:
|
||||
def test_between_requires_both_boundaries(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowRunCleanup(
|
||||
days=30, batch_size=10, start_after=datetime.datetime.now(), end_before=None, repo=FakeRepo([])
|
||||
create_cleanup(
|
||||
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=datetime.datetime.now(), end_before=None
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowRunCleanup(
|
||||
days=30, batch_size=10, start_after=None, end_before=datetime.datetime.now(), repo=FakeRepo([])
|
||||
create_cleanup(
|
||||
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=None, end_before=datetime.datetime.now()
|
||||
)
|
||||
|
||||
|
||||
def test_between_requires_end_after_start() -> None:
|
||||
def test_between_requires_end_after_start(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
start_after = datetime.datetime(2024, 6, 1, 0, 0, 0)
|
||||
end_before = datetime.datetime(2024, 5, 1, 0, 0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowRunCleanup(days=30, batch_size=10, start_after=start_after, end_before=end_before, repo=FakeRepo([]))
|
||||
create_cleanup(
|
||||
monkeypatch, repo=FakeRepo([]), days=30, batch_size=10, start_after=start_after, end_before=end_before
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue