refactor(api): use sessionmaker in builtin tools manage service (#34812)

This commit is contained in:
carlos4s 2026-04-09 00:58:38 -05:00 committed by GitHub
parent 9a51c2f56a
commit 66e588c8ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 40 deletions

View File

@ -5,7 +5,7 @@ from pathlib import Path
from typing import Any from typing import Any
from sqlalchemy import exists, select from sqlalchemy import exists, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE from constants import HIDDEN_VALUE, UNKNOWN_VALUE
@ -46,13 +46,12 @@ class BuiltinToolManageService:
delete custom oauth client params delete custom oauth client params
""" """
tool_provider = ToolProviderID(provider) tool_provider = ToolProviderID(provider)
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
session.query(ToolOAuthTenantClient).filter_by( session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=tool_provider.provider_name, provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id, plugin_id=tool_provider.plugin_id,
).delete() ).delete()
session.commit()
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
@ -150,7 +149,7 @@ class BuiltinToolManageService:
""" """
update builtin tool provider update builtin tool provider
""" """
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
# get if the provider exists # get if the provider exists
db_provider = ( db_provider = (
session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
@ -203,9 +202,7 @@ class BuiltinToolManageService:
db_provider.name = name db_provider.name = name
session.commit()
except Exception as e: except Exception as e:
session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@ -222,7 +219,7 @@ class BuiltinToolManageService:
""" """
add builtin tool provider add builtin tool provider
""" """
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
try: try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20): with redis_client.lock(lock, timeout=20):
@ -281,9 +278,7 @@ class BuiltinToolManageService:
) )
session.add(db_provider) session.add(db_provider)
session.commit()
except Exception as e: except Exception as e:
session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@ -379,7 +374,7 @@ class BuiltinToolManageService:
""" """
delete tool provider delete tool provider
""" """
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
db_provider = ( db_provider = (
session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.where( .where(
@ -393,7 +388,6 @@ class BuiltinToolManageService:
raise ValueError(f"you have not added provider {provider}") raise ValueError(f"you have not added provider {provider}")
session.delete(db_provider) session.delete(db_provider)
session.commit()
# delete cache # delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -409,7 +403,7 @@ class BuiltinToolManageService:
""" """
set default provider set default provider
""" """
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
# get provider # get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first() target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
if target_provider is None: if target_provider is None:
@ -422,7 +416,6 @@ class BuiltinToolManageService:
# set new default provider # set new default provider
target_provider.is_default = True target_provider.is_default = True
session.commit()
return {"result": "success"} return {"result": "success"}
@ -654,7 +647,7 @@ class BuiltinToolManageService:
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider") raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session: with sessionmaker(bind=db.engine).begin() as session:
custom_client_params = ( custom_client_params = (
session.query(ToolOAuthTenantClient) session.query(ToolOAuthTenantClient)
.filter_by( .filter_by(
@ -690,7 +683,6 @@ class BuiltinToolManageService:
if enable_oauth_custom_client is not None: if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod

View File

@ -15,17 +15,24 @@ def _mock_session(mock_session_cls):
return session return session
def _mock_sessionmaker(mock_sm_cls):
"""Helper: set up a sessionmaker().begin() context manager mock and return the inner session."""
session = MagicMock()
mock_sm_cls.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sm_cls.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
return session
class TestDeleteCustomOauthClientParams: class TestDeleteCustomOauthClientParams:
@patch(f"{MODULE}.Session") @patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db") @patch(f"{MODULE}.db")
def test_deletes_and_returns_success(self, mock_db, mock_session_cls): def test_deletes_and_returns_success(self, mock_db, mock_sm_cls):
session = _mock_session(mock_session_cls) session = _mock_sessionmaker(mock_sm_cls)
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google") result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
assert result == {"result": "success"} assert result == {"result": "success"}
session.query.return_value.filter_by.return_value.delete.assert_called_once() session.query.return_value.filter_by.return_value.delete.assert_called_once()
session.commit.assert_called_once()
class TestListBuiltinToolProviderTools: class TestListBuiltinToolProviderTools:
@ -138,10 +145,10 @@ class TestIsOauthCustomClientEnabled:
class TestDeleteBuiltinToolProvider: class TestDeleteBuiltinToolProvider:
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.ToolManager") @patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session") @patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db") @patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc): def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_session(mock_session_cls) session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None session.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="you have not added provider"): with pytest.raises(ValueError, match="you have not added provider"):
@ -149,10 +156,10 @@ class TestDeleteBuiltinToolProvider:
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.ToolManager") @patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session") @patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db") @patch(f"{MODULE}.db")
def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc): def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
session = _mock_session(mock_session_cls) session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock() db_provider = MagicMock()
session.query.return_value.where.return_value.first.return_value = db_provider session.query.return_value.where.return_value.first.return_value = db_provider
mock_cache = MagicMock() mock_cache = MagicMock()
@ -162,24 +169,23 @@ class TestDeleteBuiltinToolProvider:
assert result == {"result": "success"} assert result == {"result": "success"}
session.delete.assert_called_once_with(db_provider) session.delete.assert_called_once_with(db_provider)
session.commit.assert_called_once()
mock_cache.delete.assert_called_once() mock_cache.delete.assert_called_once()
class TestSetDefaultProvider: class TestSetDefaultProvider:
@patch(f"{MODULE}.Session") @patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db") @patch(f"{MODULE}.db")
def test_raises_when_not_found(self, mock_db, mock_session_cls): def test_raises_when_not_found(self, mock_db, mock_sm_cls):
session = _mock_session(mock_session_cls) session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.filter_by.return_value.first.return_value = None session.query.return_value.filter_by.return_value.first.return_value = None
with pytest.raises(ValueError, match="provider not found"): with pytest.raises(ValueError, match="provider not found"):
BuiltinToolManageService.set_default_provider("t", "u", "p", "id") BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
@patch(f"{MODULE}.Session") @patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db") @patch(f"{MODULE}.db")
def test_sets_default_and_clears_old(self, mock_db, mock_session_cls): def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
session = _mock_session(mock_session_cls) session = _mock_sessionmaker(mock_sm_cls)
target = MagicMock() target = MagicMock()
session.query.return_value.filter_by.return_value.first.return_value = target session.query.return_value.filter_by.return_value.first.return_value = target
@ -187,14 +193,13 @@ class TestSetDefaultProvider:
assert result == {"result": "success"} assert result == {"result": "success"}
assert target.is_default is True assert target.is_default is True
session.commit.assert_called_once()
class TestUpdateBuiltinToolProvider: class TestUpdateBuiltinToolProvider:
@patch(f"{MODULE}.Session") @patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db") @patch(f"{MODULE}.db")
def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls): def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
session = _mock_session(mock_session_cls) session = _mock_sessionmaker(mock_sm_cls)
session.query.return_value.where.return_value.first.return_value = None session.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError, match="you have not added provider"): with pytest.raises(ValueError, match="you have not added provider"):
@ -203,10 +208,10 @@ class TestUpdateBuiltinToolProvider:
@patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter")
@patch(f"{MODULE}.CredentialType") @patch(f"{MODULE}.CredentialType")
@patch(f"{MODULE}.ToolManager") @patch(f"{MODULE}.ToolManager")
@patch(f"{MODULE}.Session") @patch(f"{MODULE}.sessionmaker")
@patch(f"{MODULE}.db") @patch(f"{MODULE}.db")
def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc): def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
session = _mock_session(mock_session_cls) session = _mock_sessionmaker(mock_sm_cls)
db_provider = MagicMock(credential_type="api_key", credentials="{}") db_provider = MagicMock(credential_type="api_key", credentials="{}")
session.query.return_value.where.return_value.first.return_value = db_provider session.query.return_value.where.return_value.first.return_value = db_provider
@ -227,7 +232,6 @@ class TestUpdateBuiltinToolProvider:
result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"}) result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"})
assert result == {"result": "success"} assert result == {"result": "success"}
session.commit.assert_called_once()
mock_cache.delete.assert_called_once() mock_cache.delete.assert_called_once()