diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index ab116611b8..004bd7cad9 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -291,7 +291,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut session.commit() return result.rowcount - class _RunContext(TypedDict): + class RunContext(TypedDict): run_id: str tenant_id: str app_id: str @@ -299,7 +299,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut triggered_from: str @staticmethod - def delete_by_runs(session: Session, runs: Sequence[_RunContext]) -> tuple[int, int]: + def delete_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]: """ Delete node executions (and offloads) for the given workflow runs using indexed columns. diff --git a/api/services/clear_free_plan_expired_workflow_run_logs.py b/api/services/clear_free_plan_expired_workflow_run_logs.py index cd5f0d208d..17cce15311 100644 --- a/api/services/clear_free_plan_expired_workflow_run_logs.py +++ b/api/services/clear_free_plan_expired_workflow_run_logs.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from models.workflow import WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_node_execution_repository import ( DifyAPISQLAlchemyWorkflowNodeExecutionRepository, @@ -25,6 +26,7 @@ class WorkflowRunCleanup: batch_size: int, start_after: datetime.datetime | None = None, end_before: datetime.datetime | None = None, + workflow_run_repo: APIWorkflowRunRepository | None = None, ): if (start_after is None) ^ (end_before is None): raise ValueError("start_after and end_before must be both set or both omitted.") @@ -38,12 +40,16 @@ class WorkflowRunCleanup: self.batch_size = batch_size self.billing_cache: dict[str, CloudPlan | None] = {} - # Lazy import to avoid circular dependency during module import - from repositories.factory import DifyAPIRepositoryFactory + if workflow_run_repo: + self.workflow_run_repo = workflow_run_repo + else: + # Lazy import to avoid circular dependencies during module import + from repositories.factory import DifyAPIRepositoryFactory - self.workflow_run_repo: APIWorkflowRunRepository = DifyAPIRepositoryFactory.create_api_workflow_run_repository( - sessionmaker(bind=db.engine, expire_on_commit=False) - ) + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self.workflow_run_repo: APIWorkflowRunRepository = ( + DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + ) def run(self) -> None: click.echo( @@ -152,8 +158,8 @@ class WorkflowRunCleanup: trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session) return trigger_repo.delete_by_run_ids(run_ids) - def _delete_node_executions(self, session: Session, runs: Sequence[object]) -> tuple[int, int]: - run_contexts = [ + def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: + run_contexts: list[DifyAPISQLAlchemyWorkflowNodeExecutionRepository.RunContext] = [ { "run_id": run.id, "tenant_id": run.tenant_id, diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py index 3728386e34..a1685fcfb0 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -3,7 +3,6 @@ from typing import Any import pytest -import repositories.factory as repo_factory_module from services import clear_free_plan_expired_workflow_run_logs as cleanup_module from services.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup @@ -64,13 +63,7 @@ class FakeRepo: def create_cleanup(monkeypatch: pytest.MonkeyPatch, repo: FakeRepo, **kwargs: Any) -> WorkflowRunCleanup: - monkeypatch.setattr( - repo_factory_module.DifyAPIRepositoryFactory, - "create_api_workflow_run_repository", - classmethod(lambda _cls, session_maker: repo), - ) - kwargs.pop("repo", None) - return WorkflowRunCleanup(**kwargs) + return WorkflowRunCleanup(workflow_run_repo=repo, **kwargs) def test_filter_free_tenants_billing_disabled(monkeypatch: pytest.MonkeyPatch) -> None: