diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index 174bed488d..adbed87c3c 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -1,4 +1,4 @@ -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models.account import TenantPluginAutoUpgradeStrategy @@ -7,7 +7,7 @@ from models.account import TenantPluginAutoUpgradeStrategy class PluginAutoUpgradeService: @staticmethod def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: return ( session.query(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -23,7 +23,7 @@ class PluginAutoUpgradeService: exclude_plugins: list[str], include_plugins: list[str], ) -> bool: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: exist_strategy = ( session.query(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -46,12 +46,11 @@ class PluginAutoUpgradeService: exist_strategy.exclude_plugins = exclude_plugins exist_strategy.include_plugins = include_plugins - session.commit() return True @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: exist_strategy = ( session.query(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -83,5 +82,4 @@ class PluginAutoUpgradeService: exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE exist_strategy.exclude_plugins = [plugin_id] - session.commit() return True diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 60fa269640..55276d6f99 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -1,4 +1,4 @@ -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models.account import TenantPluginPermission @@ -7,7 +7,7 @@ from models.account import TenantPluginPermission class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() @staticmethod @@ -16,7 +16,7 @@ class PluginPermissionService: install_permission: TenantPluginPermission.InstallPermission, debug_permission: TenantPluginPermission.DebugPermission, ): - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: permission = ( session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() ) @@ -30,5 +30,4 @@ class PluginPermissionService: permission.install_permission = install_permission permission.debug_permission = debug_permission - session.commit() return True diff --git a/api/services/trigger/app_trigger_service.py b/api/services/trigger/app_trigger_service.py index 6d5a719f63..723d29e947 100644 --- a/api/services/trigger/app_trigger_service.py +++ b/api/services/trigger/app_trigger_service.py @@ -8,7 +8,7 @@ This service centralizes all AppTrigger-related business logic. import logging from sqlalchemy import update -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models.enums import AppTriggerStatus @@ -34,13 +34,12 @@ class AppTriggerService: """ try: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.execute( update(AppTrigger) .where(AppTrigger.tenant_id == tenant_id, AppTrigger.status == AppTriggerStatus.ENABLED) .values(status=AppTriggerStatus.RATE_LIMITED) ) - session.commit() logger.info("Marked all enabled triggers as rate limited for tenant %s", tenant_id) except Exception: logger.exception("Failed to mark all enabled triggers as rate limited for tenant %s", tenant_id) diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index d72c041609..5a5d13b96d 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -8,7 +8,7 @@ from flask import Request, Response from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse @@ -215,7 +215,7 @@ class TriggerService: not_found_in_cache.append(node_info) continue - with Session(db.engine) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: try: # lock the concurrent plugin trigger creation redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10) @@ -260,7 +260,6 @@ class TriggerService: cache.model_dump_json(), ex=60 * 60, ) - session.commit() # Update existing records if subscription_id changed for node_info in nodes_in_graph: @@ -290,14 +289,12 @@ class TriggerService: cache.model_dump_json(), ex=60 * 60, ) - session.commit() # delete the nodes not found in the graph for node_id in nodes_id_in_db: if node_id not in nodes_id_in_graph: session.delete(nodes_id_in_db[node_id]) redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}") - session.commit() except Exception: logger.exception("Failed to sync plugin trigger relationships for app %s", app.id) raise diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index f72c69a33e..8e629deb32 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -12,7 +12,7 @@ from graphon.file import FileTransferMethod from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from werkzeug.datastructures import FileStorage from werkzeug.exceptions import RequestEntityTooLarge @@ -912,7 +912,7 @@ class WebhookService: logger.warning("Failed to acquire lock for webhook sync, app %s", app.id) raise RuntimeError("Failed to acquire lock for webhook trigger synchronization") - with Session(db.engine) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # fetch the non-cached nodes from DB all_records = session.scalars( select(WorkflowWebhookTrigger).where( @@ -941,14 +941,12 @@ class WebhookService: redis_client.set( f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}", cache.model_dump_json(), ex=60 * 60 ) - session.commit() # delete the nodes not found in the graph for node_id in nodes_id_in_db: if node_id not in nodes_id_in_graph: session.delete(nodes_id_in_db[node_id]) redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}") - session.commit() except Exception: logger.exception("Failed to sync webhook relationships for app %s", app.id) raise diff --git a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py index edb50d09a6..45156958b6 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py @@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_auto_upgrade_service" def _patched_session(): - """Patch Session(db.engine) to return a mock session as context manager.""" + """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" session = MagicMock() - session_cls = MagicMock() - session_cls.return_value.__enter__ = MagicMock(return_value=session) - session_cls.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.Session", session_cls) + mock_sessionmaker = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) db_patcher = patch(f"{MODULE}.db") return patcher, db_patcher, session @@ -61,7 +61,6 @@ class TestChangeStrategy: assert result is True session.add.assert_called_once() - session.commit.assert_called_once() def test_updates_existing_strategy(self): p1, p2, session = _patched_session() @@ -86,7 +85,6 @@ class TestChangeStrategy: assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL assert existing.exclude_plugins == ["p1"] assert existing.include_plugins == ["p2"] - session.commit.assert_called_once() class TestExcludePlugin: @@ -127,7 +125,6 @@ class TestExcludePlugin: assert result is True assert existing.exclude_plugins == ["p-existing", "p-new"] - session.commit.assert_called_once() def test_removes_from_include_list_in_partial_mode(self): p1, p2, session = _patched_session() diff --git a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py index 69091110db..40f4c6a8d2 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py @@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_permission_service" def _patched_session(): - """Patch Session(db.engine) to return a mock session as context manager.""" + """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" session = MagicMock() - session_cls = MagicMock() - session_cls.return_value.__enter__ = MagicMock(return_value=session) - session_cls.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.Session", session_cls) + mock_sessionmaker = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) db_patcher = patch(f"{MODULE}.db") return patcher, db_patcher, session @@ -55,7 +55,6 @@ class TestChangePermission: ) session.add.assert_called_once() - session.commit.assert_called_once() def test_updates_existing_permission(self): p1, p2, session = _patched_session() @@ -71,5 +70,4 @@ class TestChangePermission: assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS - session.commit.assert_called_once() session.add.assert_not_called() diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 78049182ad..1b5252fc64 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -617,6 +617,20 @@ class _SessionContext: return False +class _SessionmakerContext: + def __init__(self, session: Any) -> None: + self._session = session + + def begin(self) -> "_SessionmakerContext": + return self + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + @pytest.fixture def flask_app() -> Flask: return Flask(__name__) @@ -625,6 +639,7 @@ def flask_app() -> Flask: def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None: monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock())) monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session)) + monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session)) def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger: @@ -1241,7 +1256,6 @@ def test_sync_webhook_relationships_should_create_missing_records_and_delete_sta # Assert assert len(fake_session.added) == 1 assert len(fake_session.deleted) == 1 - assert fake_session.commit_count == 2 redis_set_mock.assert_called_once() redis_delete_mock.assert_called_once() lock.release.assert_called_once()