mirror of
https://github.com/langgenius/dify.git
synced 2026-04-18 04:16:28 +08:00
refactor(api): use sessionmaker in builtin tools manage service (#34812)
This commit is contained in:
parent
9a51c2f56a
commit
66e588c8ca
@ -5,7 +5,7 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import exists, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||||
@ -46,13 +46,12 @@ class BuiltinToolManageService:
|
|||||||
delete custom oauth client params
|
delete custom oauth client params
|
||||||
"""
|
"""
|
||||||
tool_provider = ToolProviderID(provider)
|
tool_provider = ToolProviderID(provider)
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
session.query(ToolOAuthTenantClient).filter_by(
|
session.query(ToolOAuthTenantClient).filter_by(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=tool_provider.provider_name,
|
provider=tool_provider.provider_name,
|
||||||
plugin_id=tool_provider.plugin_id,
|
plugin_id=tool_provider.plugin_id,
|
||||||
).delete()
|
).delete()
|
||||||
session.commit()
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -150,7 +149,7 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
update builtin tool provider
|
update builtin tool provider
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
# get if the provider exists
|
# get if the provider exists
|
||||||
db_provider = (
|
db_provider = (
|
||||||
session.query(BuiltinToolProvider)
|
session.query(BuiltinToolProvider)
|
||||||
@ -203,9 +202,7 @@ class BuiltinToolManageService:
|
|||||||
|
|
||||||
db_provider.name = name
|
db_provider.name = name
|
||||||
|
|
||||||
session.commit()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
session.rollback()
|
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@ -222,7 +219,7 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
add builtin tool provider
|
add builtin tool provider
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
try:
|
try:
|
||||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||||
with redis_client.lock(lock, timeout=20):
|
with redis_client.lock(lock, timeout=20):
|
||||||
@ -281,9 +278,7 @@ class BuiltinToolManageService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
session.add(db_provider)
|
session.add(db_provider)
|
||||||
session.commit()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
session.rollback()
|
|
||||||
raise ValueError(str(e))
|
raise ValueError(str(e))
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
@ -379,7 +374,7 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
delete tool provider
|
delete tool provider
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
db_provider = (
|
db_provider = (
|
||||||
session.query(BuiltinToolProvider)
|
session.query(BuiltinToolProvider)
|
||||||
.where(
|
.where(
|
||||||
@ -393,7 +388,6 @@ class BuiltinToolManageService:
|
|||||||
raise ValueError(f"you have not added provider {provider}")
|
raise ValueError(f"you have not added provider {provider}")
|
||||||
|
|
||||||
session.delete(db_provider)
|
session.delete(db_provider)
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# delete cache
|
# delete cache
|
||||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||||
@ -409,7 +403,7 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
set default provider
|
set default provider
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
# get provider
|
# get provider
|
||||||
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
|
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
|
||||||
if target_provider is None:
|
if target_provider is None:
|
||||||
@ -422,7 +416,6 @@ class BuiltinToolManageService:
|
|||||||
|
|
||||||
# set new default provider
|
# set new default provider
|
||||||
target_provider.is_default = True
|
target_provider.is_default = True
|
||||||
session.commit()
|
|
||||||
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@ -654,7 +647,7 @@ class BuiltinToolManageService:
|
|||||||
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
||||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
custom_client_params = (
|
custom_client_params = (
|
||||||
session.query(ToolOAuthTenantClient)
|
session.query(ToolOAuthTenantClient)
|
||||||
.filter_by(
|
.filter_by(
|
||||||
@ -690,7 +683,6 @@ class BuiltinToolManageService:
|
|||||||
if enable_oauth_custom_client is not None:
|
if enable_oauth_custom_client is not None:
|
||||||
custom_client_params.enabled = enable_oauth_custom_client
|
custom_client_params.enabled = enable_oauth_custom_client
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -15,17 +15,24 @@ def _mock_session(mock_session_cls):
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_sessionmaker(mock_sm_cls):
|
||||||
|
"""Helper: set up a sessionmaker().begin() context manager mock and return the inner session."""
|
||||||
|
session = MagicMock()
|
||||||
|
mock_sm_cls.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
|
||||||
|
mock_sm_cls.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
class TestDeleteCustomOauthClientParams:
|
class TestDeleteCustomOauthClientParams:
|
||||||
@patch(f"{MODULE}.Session")
|
@patch(f"{MODULE}.sessionmaker")
|
||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_deletes_and_returns_success(self, mock_db, mock_session_cls):
|
def test_deletes_and_returns_success(self, mock_db, mock_sm_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
|
|
||||||
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
||||||
|
|
||||||
assert result == {"result": "success"}
|
assert result == {"result": "success"}
|
||||||
session.query.return_value.filter_by.return_value.delete.assert_called_once()
|
session.query.return_value.filter_by.return_value.delete.assert_called_once()
|
||||||
session.commit.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestListBuiltinToolProviderTools:
|
class TestListBuiltinToolProviderTools:
|
||||||
@ -138,10 +145,10 @@ class TestIsOauthCustomClientEnabled:
|
|||||||
class TestDeleteBuiltinToolProvider:
|
class TestDeleteBuiltinToolProvider:
|
||||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||||
@patch(f"{MODULE}.ToolManager")
|
@patch(f"{MODULE}.ToolManager")
|
||||||
@patch(f"{MODULE}.Session")
|
@patch(f"{MODULE}.sessionmaker")
|
||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.query.return_value.where.return_value.first.return_value = None
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="you have not added provider"):
|
with pytest.raises(ValueError, match="you have not added provider"):
|
||||||
@ -149,10 +156,10 @@ class TestDeleteBuiltinToolProvider:
|
|||||||
|
|
||||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||||
@patch(f"{MODULE}.ToolManager")
|
@patch(f"{MODULE}.ToolManager")
|
||||||
@patch(f"{MODULE}.Session")
|
@patch(f"{MODULE}.sessionmaker")
|
||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc):
|
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
db_provider = MagicMock()
|
db_provider = MagicMock()
|
||||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||||
mock_cache = MagicMock()
|
mock_cache = MagicMock()
|
||||||
@ -162,24 +169,23 @@ class TestDeleteBuiltinToolProvider:
|
|||||||
|
|
||||||
assert result == {"result": "success"}
|
assert result == {"result": "success"}
|
||||||
session.delete.assert_called_once_with(db_provider)
|
session.delete.assert_called_once_with(db_provider)
|
||||||
session.commit.assert_called_once()
|
|
||||||
mock_cache.delete.assert_called_once()
|
mock_cache.delete.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestSetDefaultProvider:
|
class TestSetDefaultProvider:
|
||||||
@patch(f"{MODULE}.Session")
|
@patch(f"{MODULE}.sessionmaker")
|
||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_raises_when_not_found(self, mock_db, mock_session_cls):
|
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="provider not found"):
|
with pytest.raises(ValueError, match="provider not found"):
|
||||||
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||||
|
|
||||||
@patch(f"{MODULE}.Session")
|
@patch(f"{MODULE}.sessionmaker")
|
||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_sets_default_and_clears_old(self, mock_db, mock_session_cls):
|
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
target = MagicMock()
|
target = MagicMock()
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = target
|
session.query.return_value.filter_by.return_value.first.return_value = target
|
||||||
|
|
||||||
@ -187,14 +193,13 @@ class TestSetDefaultProvider:
|
|||||||
|
|
||||||
assert result == {"result": "success"}
|
assert result == {"result": "success"}
|
||||||
assert target.is_default is True
|
assert target.is_default is True
|
||||||
session.commit.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestUpdateBuiltinToolProvider:
|
class TestUpdateBuiltinToolProvider:
|
||||||
@patch(f"{MODULE}.Session")
|
@patch(f"{MODULE}.sessionmaker")
|
||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls):
|
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.query.return_value.where.return_value.first.return_value = None
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="you have not added provider"):
|
with pytest.raises(ValueError, match="you have not added provider"):
|
||||||
@ -203,10 +208,10 @@ class TestUpdateBuiltinToolProvider:
|
|||||||
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
|
||||||
@patch(f"{MODULE}.CredentialType")
|
@patch(f"{MODULE}.CredentialType")
|
||||||
@patch(f"{MODULE}.ToolManager")
|
@patch(f"{MODULE}.ToolManager")
|
||||||
@patch(f"{MODULE}.Session")
|
@patch(f"{MODULE}.sessionmaker")
|
||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc):
|
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
||||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
session.query.return_value.where.return_value.first.return_value = db_provider
|
||||||
|
|
||||||
@ -227,7 +232,6 @@ class TestUpdateBuiltinToolProvider:
|
|||||||
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
|
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
|
||||||
|
|
||||||
assert result == {"result": "success"}
|
assert result == {"result": "success"}
|
||||||
session.commit.assert_called_once()
|
|
||||||
mock_cache.delete.assert_called_once()
|
mock_cache.delete.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user