mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 11:56:55 +08:00
refactor: migrate session.query to select API in plugin services (#34817)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
d360929af1
commit
ee789db443
@ -1,3 +1,4 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -8,10 +9,10 @@ class PluginAutoUpgradeService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
return (
|
return session.scalar(
|
||||||
session.query(TenantPluginAutoUpgradeStrategy)
|
select(TenantPluginAutoUpgradeStrategy)
|
||||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -24,10 +25,10 @@ class PluginAutoUpgradeService:
|
|||||||
include_plugins: list[str],
|
include_plugins: list[str],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
exist_strategy = (
|
exist_strategy = session.scalar(
|
||||||
session.query(TenantPluginAutoUpgradeStrategy)
|
select(TenantPluginAutoUpgradeStrategy)
|
||||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not exist_strategy:
|
if not exist_strategy:
|
||||||
strategy = TenantPluginAutoUpgradeStrategy(
|
strategy = TenantPluginAutoUpgradeStrategy(
|
||||||
@ -51,10 +52,10 @@ class PluginAutoUpgradeService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
exist_strategy = (
|
exist_strategy = session.scalar(
|
||||||
session.query(TenantPluginAutoUpgradeStrategy)
|
select(TenantPluginAutoUpgradeStrategy)
|
||||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not exist_strategy:
|
if not exist_strategy:
|
||||||
# create for this tenant
|
# create for this tenant
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -8,7 +9,9 @@ class PluginPermissionService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
return session.scalar(
|
||||||
|
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def change_permission(
|
def change_permission(
|
||||||
@ -17,8 +20,8 @@ class PluginPermissionService:
|
|||||||
debug_permission: TenantPluginPermission.DebugPermission,
|
debug_permission: TenantPluginPermission.DebugPermission,
|
||||||
):
|
):
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
permission = (
|
permission = session.scalar(
|
||||||
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
|
select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1)
|
||||||
)
|
)
|
||||||
if not permission:
|
if not permission:
|
||||||
permission = TenantPluginPermission(
|
permission = TenantPluginPermission(
|
||||||
|
|||||||
@ -20,7 +20,7 @@ class TestGetStrategy:
|
|||||||
def test_returns_strategy_when_found(self):
|
def test_returns_strategy_when_found(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
strategy = MagicMock()
|
strategy = MagicMock()
|
||||||
session.query.return_value.where.return_value.first.return_value = strategy
|
session.scalar.return_value = strategy
|
||||||
|
|
||||||
with p1, p2:
|
with p1, p2:
|
||||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||||
@ -31,7 +31,7 @@ class TestGetStrategy:
|
|||||||
|
|
||||||
def test_returns_none_when_not_found(self):
|
def test_returns_none_when_not_found(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with p1, p2:
|
with p1, p2:
|
||||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||||
@ -44,9 +44,9 @@ class TestGetStrategy:
|
|||||||
class TestChangeStrategy:
|
class TestChangeStrategy:
|
||||||
def test_creates_new_strategy(self):
|
def test_creates_new_strategy(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||||
strat_cls.return_value = MagicMock()
|
strat_cls.return_value = MagicMock()
|
||||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ class TestChangeStrategy:
|
|||||||
def test_updates_existing_strategy(self):
|
def test_updates_existing_strategy(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
existing = MagicMock()
|
existing = MagicMock()
|
||||||
session.query.return_value.where.return_value.first.return_value = existing
|
session.scalar.return_value = existing
|
||||||
|
|
||||||
with p1, p2:
|
with p1, p2:
|
||||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||||
@ -90,11 +90,12 @@ class TestChangeStrategy:
|
|||||||
class TestExcludePlugin:
|
class TestExcludePlugin:
|
||||||
def test_creates_default_strategy_when_none_exists(self):
|
def test_creates_default_strategy_when_none_exists(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with (
|
with (
|
||||||
p1,
|
p1,
|
||||||
p2,
|
p2,
|
||||||
|
patch(f"{MODULE}.select"),
|
||||||
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
|
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
|
||||||
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
|
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
|
||||||
):
|
):
|
||||||
@ -113,9 +114,9 @@ class TestExcludePlugin:
|
|||||||
existing = MagicMock()
|
existing = MagicMock()
|
||||||
existing.upgrade_mode = "exclude"
|
existing.upgrade_mode = "exclude"
|
||||||
existing.exclude_plugins = ["p-existing"]
|
existing.exclude_plugins = ["p-existing"]
|
||||||
session.query.return_value.where.return_value.first.return_value = existing
|
session.scalar.return_value = existing
|
||||||
|
|
||||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||||
strat_cls.UpgradeMode.ALL = "all"
|
strat_cls.UpgradeMode.ALL = "all"
|
||||||
@ -131,9 +132,9 @@ class TestExcludePlugin:
|
|||||||
existing = MagicMock()
|
existing = MagicMock()
|
||||||
existing.upgrade_mode = "partial"
|
existing.upgrade_mode = "partial"
|
||||||
existing.include_plugins = ["p1", "p2"]
|
existing.include_plugins = ["p1", "p2"]
|
||||||
session.query.return_value.where.return_value.first.return_value = existing
|
session.scalar.return_value = existing
|
||||||
|
|
||||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||||
strat_cls.UpgradeMode.ALL = "all"
|
strat_cls.UpgradeMode.ALL = "all"
|
||||||
@ -148,9 +149,9 @@ class TestExcludePlugin:
|
|||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
existing = MagicMock()
|
existing = MagicMock()
|
||||||
existing.upgrade_mode = "all"
|
existing.upgrade_mode = "all"
|
||||||
session.query.return_value.where.return_value.first.return_value = existing
|
session.scalar.return_value = existing
|
||||||
|
|
||||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||||
strat_cls.UpgradeMode.ALL = "all"
|
strat_cls.UpgradeMode.ALL = "all"
|
||||||
@ -167,9 +168,9 @@ class TestExcludePlugin:
|
|||||||
existing = MagicMock()
|
existing = MagicMock()
|
||||||
existing.upgrade_mode = "exclude"
|
existing.upgrade_mode = "exclude"
|
||||||
existing.exclude_plugins = ["p1"]
|
existing.exclude_plugins = ["p1"]
|
||||||
session.query.return_value.where.return_value.first.return_value = existing
|
session.scalar.return_value = existing
|
||||||
|
|
||||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||||
strat_cls.UpgradeMode.ALL = "all"
|
strat_cls.UpgradeMode.ALL = "all"
|
||||||
|
|||||||
@ -20,7 +20,7 @@ class TestGetPermission:
|
|||||||
def test_returns_permission_when_found(self):
|
def test_returns_permission_when_found(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
permission = MagicMock()
|
permission = MagicMock()
|
||||||
session.query.return_value.where.return_value.first.return_value = permission
|
session.scalar.return_value = permission
|
||||||
|
|
||||||
with p1, p2:
|
with p1, p2:
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
@ -31,7 +31,7 @@ class TestGetPermission:
|
|||||||
|
|
||||||
def test_returns_none_when_not_found(self):
|
def test_returns_none_when_not_found(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with p1, p2:
|
with p1, p2:
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
@ -44,9 +44,9 @@ class TestGetPermission:
|
|||||||
class TestChangePermission:
|
class TestChangePermission:
|
||||||
def test_creates_new_permission_when_not_exists(self):
|
def test_creates_new_permission_when_not_exists(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||||
perm_cls.return_value = MagicMock()
|
perm_cls.return_value = MagicMock()
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
|
|
||||||
@ -59,7 +59,7 @@ class TestChangePermission:
|
|||||||
def test_updates_existing_permission(self):
|
def test_updates_existing_permission(self):
|
||||||
p1, p2, session = _patched_session()
|
p1, p2, session = _patched_session()
|
||||||
existing = MagicMock()
|
existing = MagicMock()
|
||||||
session.query.return_value.where.return_value.first.return_value = existing
|
session.scalar.return_value = existing
|
||||||
|
|
||||||
with p1, p2:
|
with p1, p2:
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user