refactor(api): use sessionmaker in plugin & trigger services (#34764)

This commit is contained in:
carlos4s 2026-04-08 18:18:26 -05:00 committed by GitHub
parent 02a9f0abca
commit 1d971d3240
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 38 additions and 38 deletions

View File

@ -1,4 +1,4 @@
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models.account import TenantPluginAutoUpgradeStrategy
@ -7,7 +7,7 @@ from models.account import TenantPluginAutoUpgradeStrategy
class PluginAutoUpgradeService:
@staticmethod
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
return (
session.query(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
@ -23,7 +23,7 @@ class PluginAutoUpgradeService:
exclude_plugins: list[str],
include_plugins: list[str],
) -> bool:
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
@ -46,12 +46,11 @@ class PluginAutoUpgradeService:
exist_strategy.exclude_plugins = exclude_plugins
exist_strategy.include_plugins = include_plugins
session.commit()
return True
@staticmethod
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
@ -83,5 +82,4 @@ class PluginAutoUpgradeService:
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
exist_strategy.exclude_plugins = [plugin_id]
session.commit()
return True

View File

@ -1,4 +1,4 @@
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models.account import TenantPluginPermission
@ -7,7 +7,7 @@ from models.account import TenantPluginPermission
class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
@staticmethod
@ -16,7 +16,7 @@ class PluginPermissionService:
install_permission: TenantPluginPermission.InstallPermission,
debug_permission: TenantPluginPermission.DebugPermission,
):
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
permission = (
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
)
@ -30,5 +30,4 @@ class PluginPermissionService:
permission.install_permission = install_permission
permission.debug_permission = debug_permission
session.commit()
return True

View File

@ -8,7 +8,7 @@ This service centralizes all AppTrigger-related business logic.
import logging
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models.enums import AppTriggerStatus
@ -34,13 +34,12 @@ class AppTriggerService:
"""
try:
with Session(db.engine) as session:
with sessionmaker(bind=db.engine).begin() as session:
session.execute(
update(AppTrigger)
.where(AppTrigger.tenant_id == tenant_id, AppTrigger.status == AppTriggerStatus.ENABLED)
.values(status=AppTriggerStatus.RATE_LIMITED)
)
session.commit()
logger.info("Marked all enabled triggers as rate limited for tenant %s", tenant_id)
except Exception:
logger.exception("Failed to mark all enabled triggers as rate limited for tenant %s", tenant_id)

View File

@ -8,7 +8,7 @@ from flask import Request, Response
from graphon.entities.graph_config import NodeConfigDict
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse
@ -215,7 +215,7 @@ class TriggerService:
not_found_in_cache.append(node_info)
continue
with Session(db.engine) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
try:
# lock the concurrent plugin trigger creation
redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
@ -260,7 +260,6 @@ class TriggerService:
cache.model_dump_json(),
ex=60 * 60,
)
session.commit()
# Update existing records if subscription_id changed
for node_info in nodes_in_graph:
@ -290,14 +289,12 @@ class TriggerService:
cache.model_dump_json(),
ex=60 * 60,
)
session.commit()
# delete the nodes not found in the graph
for node_id in nodes_id_in_db:
if node_id not in nodes_id_in_graph:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync plugin trigger relationships for app %s", app.id)
raise

View File

@ -12,7 +12,7 @@ from graphon.file import FileTransferMethod
from graphon.variables.types import ArrayValidation, SegmentType
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import RequestEntityTooLarge
@ -912,7 +912,7 @@ class WebhookService:
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
with Session(db.engine) as session:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
# fetch the non-cached nodes from DB
all_records = session.scalars(
select(WorkflowWebhookTrigger).where(
@ -941,14 +941,12 @@ class WebhookService:
redis_client.set(
f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}", cache.model_dump_json(), ex=60 * 60
)
session.commit()
# delete the nodes not found in the graph
for node_id in nodes_id_in_db:
if node_id not in nodes_id_in_graph:
session.delete(nodes_id_in_db[node_id])
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
session.commit()
except Exception:
logger.exception("Failed to sync webhook relationships for app %s", app.id)
raise

View File

@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_auto_upgrade_service"
def _patched_session():
"""Patch Session(db.engine) to return a mock session as context manager."""
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
session = MagicMock()
session_cls = MagicMock()
session_cls.return_value.__enter__ = MagicMock(return_value=session)
session_cls.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.Session", session_cls)
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
@ -61,7 +61,6 @@ class TestChangeStrategy:
assert result is True
session.add.assert_called_once()
session.commit.assert_called_once()
def test_updates_existing_strategy(self):
p1, p2, session = _patched_session()
@ -86,7 +85,6 @@ class TestChangeStrategy:
assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
assert existing.exclude_plugins == ["p1"]
assert existing.include_plugins == ["p2"]
session.commit.assert_called_once()
class TestExcludePlugin:
@ -127,7 +125,6 @@ class TestExcludePlugin:
assert result is True
assert existing.exclude_plugins == ["p-existing", "p-new"]
session.commit.assert_called_once()
def test_removes_from_include_list_in_partial_mode(self):
p1, p2, session = _patched_session()

View File

@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_permission_service"
def _patched_session():
"""Patch Session(db.engine) to return a mock session as context manager."""
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
session = MagicMock()
session_cls = MagicMock()
session_cls.return_value.__enter__ = MagicMock(return_value=session)
session_cls.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.Session", session_cls)
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
@ -55,7 +55,6 @@ class TestChangePermission:
)
session.add.assert_called_once()
session.commit.assert_called_once()
def test_updates_existing_permission(self):
p1, p2, session = _patched_session()
@ -71,5 +70,4 @@ class TestChangePermission:
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
session.commit.assert_called_once()
session.add.assert_not_called()

View File

@ -617,6 +617,20 @@ class _SessionContext:
return False
class _SessionmakerContext:
def __init__(self, session: Any) -> None:
self._session = session
def begin(self) -> "_SessionmakerContext":
return self
def __enter__(self) -> Any:
return self._session
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
return False
@pytest.fixture
def flask_app() -> Flask:
return Flask(__name__)
@ -625,6 +639,7 @@ def flask_app() -> Flask:
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
@ -1241,7 +1256,6 @@ def test_sync_webhook_relationships_should_create_missing_records_and_delete_sta
# Assert
assert len(fake_session.added) == 1
assert len(fake_session.deleted) == 1
assert fake_session.commit_count == 2
redis_set_mock.assert_called_once()
redis_delete_mock.assert_called_once()
lock.release.assert_called_once()