From 5aa4e23f54ea1bb143f40a051112d5696e3173af Mon Sep 17 00:00:00 2001 From: carlos4s <71615127+carlos4s@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:21:28 -0500 Subject: [PATCH] refactor(api): use sessionmaker in end user, retention & cleanup services (#34765) --- .../clear_free_plan_tenant_expired_logs.py | 9 ++-- api/services/end_user_service.py | 11 ++--- .../conversation/messages_clean_service.py | 10 ++--- ...est_clear_free_plan_tenant_expired_logs.py | 45 +++++++++---------- 4 files changed, 33 insertions(+), 42 deletions(-) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index b4a7fa051f..b0f7efaccd 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -120,7 +120,7 @@ class ClearFreePlanTenantExpiredLogs: apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all() app_ids = [app.id for app in apps] while True: - with Session(db.engine).no_autoflush as session: + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: messages = ( session.query(Message) .where( @@ -152,7 +152,6 @@ class ClearFreePlanTenantExpiredLogs: ).delete(synchronize_session=False) cls._clear_message_related_tables(session, tenant_id, message_ids) - session.commit() click.echo( click.style( @@ -161,7 +160,7 @@ class ClearFreePlanTenantExpiredLogs: ) while True: - with Session(db.engine).no_autoflush as session: + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: conversations = ( session.query(Conversation) .where( @@ -190,7 +189,6 @@ class ClearFreePlanTenantExpiredLogs: session.query(Conversation).where( Conversation.id.in_(conversation_ids), ).delete(synchronize_session=False) - session.commit() click.echo( click.style( @@ -294,7 +292,7 @@ class ClearFreePlanTenantExpiredLogs: break while True: - with Session(db.engine).no_autoflush as session: + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: workflow_app_logs = ( session.query(WorkflowAppLog) .where( @@ -326,7 +324,6 @@ class ClearFreePlanTenantExpiredLogs: session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete( synchronize_session=False ) - session.commit() click.echo( click.style( diff --git a/api/services/end_user_service.py b/api/services/end_user_service.py index 29ada270ec..749d8dbc30 100644 --- a/api/services/end_user_service.py +++ b/api/services/end_user_service.py @@ -2,7 +2,7 @@ import logging from collections.abc import Mapping from sqlalchemy import case, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db @@ -24,7 +24,7 @@ class EndUserService: when an end-user ID is known. """ - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: return session.scalar( select(EndUser) .where( @@ -54,7 +54,7 @@ class EndUserService: if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility # This single query approach is more efficient than separate queries end_user = session.scalar( @@ -82,7 +82,6 @@ class EndUserService: user_id, ) end_user.type = type - session.commit() else: # Create new end user if none exists end_user = EndUser( @@ -94,7 +93,6 @@ class EndUserService: external_user_id=user_id, ) session.add(end_user) - session.commit() return end_user @@ -135,7 +133,7 @@ class EndUserService: if not unique_app_ids: return result - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Fetch existing end users for all target apps in a single query existing_end_users: list[EndUser] = list( session.scalars( @@ -174,7 +172,6 @@ class EndUserService: ) session.add_all(new_end_users) - session.commit() for eu in new_end_users: result[eu.app_id] = eu diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py index 0e0dbab2d1..1e9f0bf149 100644 --- a/api/services/retention/conversation/messages_clean_service.py +++ b/api/services/retention/conversation/messages_clean_service.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, TypedDict, cast import sqlalchemy as sa from sqlalchemy import delete, select, tuple_ from sqlalchemy.engine import CursorResult -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from extensions.ext_database import db @@ -369,7 +369,7 @@ class MessagesCleanService: batch_deleted_messages = 0 # Step 1: Fetch a batch of messages using cursor - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: fetch_messages_start = time.monotonic() msg_stmt = ( select(Message.id, Message.app_id, Message.created_at) @@ -477,7 +477,7 @@ class MessagesCleanService: # Step 4: Batch delete messages and their relations if not self._dry_run: - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: delete_relations_start = time.monotonic() # Delete related records first self._batch_delete_message_relations(session, message_ids_to_delete) @@ -489,9 +489,7 @@ class MessagesCleanService: delete_result = cast(CursorResult, session.execute(delete_stmt)) messages_deleted = delete_result.rowcount delete_messages_ms = int((time.monotonic() - delete_messages_start) * 1000) - commit_start = time.monotonic() - session.commit() - commit_ms = int((time.monotonic() - commit_start) * 1000) + commit_ms = 0 stats["total_deleted"] += messages_deleted batch_deleted_messages = messages_deleted diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index f393a4b10b..3e989c55a3 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -275,48 +275,46 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) - msg_session_1.query.side_effect = lambda model: ( make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() ) - msg_session_1.commit.return_value = None - msg_session_2 = MagicMock() msg_session_2.query.side_effect = lambda model: ( make_query_with_batches([[]]) if model == service_module.Message else MagicMock() ) - msg_session_2.commit.return_value = None conv_session_1 = MagicMock() conv_session_1.query.side_effect = lambda model: ( make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() ) - conv_session_1.commit.return_value = None conv_session_2 = MagicMock() conv_session_2.query.side_effect = lambda model: ( make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() ) - conv_session_2.commit.return_value = None wal_session_1 = MagicMock() wal_session_1.query.side_effect = lambda model: ( make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() ) - wal_session_1.commit.return_value = None wal_session_2 = MagicMock() wal_session_2.query.side_effect = lambda model: ( make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() ) - wal_session_2.commit.return_value = None session_wrappers = [ - _session_wrapper_for_no_autoflush(msg_session_1), - _session_wrapper_for_no_autoflush(msg_session_2), - _session_wrapper_for_no_autoflush(conv_session_1), - _session_wrapper_for_no_autoflush(conv_session_2), - _session_wrapper_for_no_autoflush(wal_session_1), - _session_wrapper_for_no_autoflush(wal_session_2), + _sessionmaker_wrapper_for_begin(msg_session_1), + _sessionmaker_wrapper_for_begin(msg_session_2), + _sessionmaker_wrapper_for_begin(conv_session_1), + _sessionmaker_wrapper_for_begin(conv_session_2), + _sessionmaker_wrapper_for_begin(wal_session_1), + _sessionmaker_wrapper_for_begin(wal_session_2), ] - monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + def fake_sessionmaker(*args, **kwargs): + if kwargs.get("autoflush") is False: + return session_wrappers.pop(0) + return object() + + monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker) def fake_select(*_args, **_kwargs): stmt = MagicMock() @@ -333,8 +331,6 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) - run_repo = MagicMock() run_repo.get_expired_runs_batch.side_effect = [[SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"})], []] run_repo.delete_runs_by_ids.return_value = 1 - - monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) monkeypatch.setattr( service_module.DifyAPIRepositoryFactory, "create_api_workflow_node_execution_repository", @@ -574,13 +570,18 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte q_empty.limit.return_value = q_empty q_empty.all.return_value = [] empty_session.query.return_value = q_empty - empty_session.commit.return_value = None session_wrappers = [ - _session_wrapper_for_no_autoflush(empty_session), - _session_wrapper_for_no_autoflush(empty_session), - _session_wrapper_for_no_autoflush(empty_session), + _sessionmaker_wrapper_for_begin(empty_session), + _sessionmaker_wrapper_for_begin(empty_session), + _sessionmaker_wrapper_for_begin(empty_session), ] - monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + + def fake_sessionmaker(*args, **kwargs): + if kwargs.get("autoflush") is False: + return session_wrappers.pop(0) + return object() + + monkeypatch.setattr(service_module, "sessionmaker", fake_sessionmaker) def fake_select(*_args, **_kwargs): stmt = MagicMock() @@ -606,8 +607,6 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte [], ] run_repo.delete_runs_by_ids.return_value = 2 - - monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) monkeypatch.setattr( service_module.DifyAPIRepositoryFactory, "create_api_workflow_node_execution_repository",