mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:57:03 +08:00
refactor(api): use sessionmaker in plugin & trigger services (#34764)
This commit is contained in:
parent
02a9f0abca
commit
1d971d3240
@ -1,4 +1,4 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
@ -7,7 +7,7 @@ from models.account import TenantPluginAutoUpgradeStrategy
|
||||
class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
@ -23,7 +23,7 @@ class PluginAutoUpgradeService:
|
||||
exclude_plugins: list[str],
|
||||
include_plugins: list[str],
|
||||
) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
@ -46,12 +46,11 @@ class PluginAutoUpgradeService:
|
||||
exist_strategy.exclude_plugins = exclude_plugins
|
||||
exist_strategy.include_plugins = include_plugins
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
@ -83,5 +82,4 @@ class PluginAutoUpgradeService:
|
||||
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
exist_strategy.exclude_plugins = [plugin_id]
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
@ -7,7 +7,7 @@ from models.account import TenantPluginPermission
|
||||
class PluginPermissionService:
|
||||
@staticmethod
|
||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
|
||||
@staticmethod
|
||||
@ -16,7 +16,7 @@ class PluginPermissionService:
|
||||
install_permission: TenantPluginPermission.InstallPermission,
|
||||
debug_permission: TenantPluginPermission.DebugPermission,
|
||||
):
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
||||
)
|
||||
@ -30,5 +30,4 @@ class PluginPermissionService:
|
||||
permission.install_permission = install_permission
|
||||
permission.debug_permission = debug_permission
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@ -8,7 +8,7 @@ This service centralizes all AppTrigger-related business logic.
|
||||
import logging
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import AppTriggerStatus
|
||||
@ -34,13 +34,12 @@ class AppTriggerService:
|
||||
|
||||
"""
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
session.execute(
|
||||
update(AppTrigger)
|
||||
.where(AppTrigger.tenant_id == tenant_id, AppTrigger.status == AppTriggerStatus.ENABLED)
|
||||
.values(status=AppTriggerStatus.RATE_LIMITED)
|
||||
)
|
||||
session.commit()
|
||||
logger.info("Marked all enabled triggers as rate limited for tenant %s", tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark all enabled triggers as rate limited for tenant %s", tenant_id)
|
||||
|
||||
@ -8,7 +8,7 @@ from flask import Request, Response
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse
|
||||
@ -215,7 +215,7 @@ class TriggerService:
|
||||
not_found_in_cache.append(node_info)
|
||||
continue
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
try:
|
||||
# lock the concurrent plugin trigger creation
|
||||
redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10)
|
||||
@ -260,7 +260,6 @@ class TriggerService:
|
||||
cache.model_dump_json(),
|
||||
ex=60 * 60,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Update existing records if subscription_id changed
|
||||
for node_info in nodes_in_graph:
|
||||
@ -290,14 +289,12 @@ class TriggerService:
|
||||
cache.model_dump_json(),
|
||||
ex=60 * 60,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# delete the nodes not found in the graph
|
||||
for node_id in nodes_id_in_db:
|
||||
if node_id not in nodes_id_in_graph:
|
||||
session.delete(nodes_id_in_db[node_id])
|
||||
redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}")
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to sync plugin trigger relationships for app %s", app.id)
|
||||
raise
|
||||
|
||||
@ -12,7 +12,7 @@ from graphon.file import FileTransferMethod
|
||||
from graphon.variables.types import ArrayValidation, SegmentType
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import RequestEntityTooLarge
|
||||
|
||||
@ -912,7 +912,7 @@ class WebhookService:
|
||||
logger.warning("Failed to acquire lock for webhook sync, app %s", app.id)
|
||||
raise RuntimeError("Failed to acquire lock for webhook trigger synchronization")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
# fetch the non-cached nodes from DB
|
||||
all_records = session.scalars(
|
||||
select(WorkflowWebhookTrigger).where(
|
||||
@ -941,14 +941,12 @@ class WebhookService:
|
||||
redis_client.set(
|
||||
f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}", cache.model_dump_json(), ex=60 * 60
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# delete the nodes not found in the graph
|
||||
for node_id in nodes_id_in_db:
|
||||
if node_id not in nodes_id_in_graph:
|
||||
session.delete(nodes_id_in_db[node_id])
|
||||
redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}")
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to sync webhook relationships for app %s", app.id)
|
||||
raise
|
||||
|
||||
@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_auto_upgrade_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
@ -61,7 +61,6 @@ class TestChangeStrategy:
|
||||
|
||||
assert result is True
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
@ -86,7 +85,6 @@ class TestChangeStrategy:
|
||||
assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
assert existing.include_plugins == ["p2"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestExcludePlugin:
|
||||
@ -127,7 +125,6 @@ class TestExcludePlugin:
|
||||
|
||||
assert result is True
|
||||
assert existing.exclude_plugins == ["p-existing", "p-new"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_removes_from_include_list_in_partial_mode(self):
|
||||
p1, p2, session = _patched_session()
|
||||
|
||||
@ -6,12 +6,12 @@ MODULE = "services.plugin.plugin_permission_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
@ -55,7 +55,6 @@ class TestChangePermission:
|
||||
)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
@ -71,5 +70,4 @@ class TestChangePermission:
|
||||
|
||||
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
session.commit.assert_called_once()
|
||||
session.add.assert_not_called()
|
||||
|
||||
@ -617,6 +617,20 @@ class _SessionContext:
|
||||
return False
|
||||
|
||||
|
||||
class _SessionmakerContext:
|
||||
def __init__(self, session: Any) -> None:
|
||||
self._session = session
|
||||
|
||||
def begin(self) -> "_SessionmakerContext":
|
||||
return self
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app() -> Flask:
|
||||
return Flask(__name__)
|
||||
@ -625,6 +639,7 @@ def flask_app() -> Flask:
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, session: Any) -> None:
|
||||
monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=MagicMock(), session=MagicMock()))
|
||||
monkeypatch.setattr(service_module, "Session", lambda *args, **kwargs: _SessionContext(session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda *args, **kwargs: _SessionmakerContext(session))
|
||||
|
||||
|
||||
def _workflow_trigger(**kwargs: Any) -> WorkflowWebhookTrigger:
|
||||
@ -1241,7 +1256,6 @@ def test_sync_webhook_relationships_should_create_missing_records_and_delete_sta
|
||||
# Assert
|
||||
assert len(fake_session.added) == 1
|
||||
assert len(fake_session.deleted) == 1
|
||||
assert fake_session.commit_count == 2
|
||||
redis_set_mock.assert_called_once()
|
||||
redis_delete_mock.assert_called_once()
|
||||
lock.release.assert_called_once()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user