mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 23:38:53 +08:00
refactor(services): migrate builtin_tools_manage_service to SQLAlchemy 2.0 select() API (#34973)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
45561bed9d
commit
4ef67fef3a
@ -4,7 +4,7 @@ from collections.abc import Mapping
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import delete, exists, func, select, update
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -47,11 +47,15 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
tool_provider = ToolProviderID(provider)
|
tool_provider = ToolProviderID(provider)
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
session.query(ToolOAuthTenantClient).filter_by(
|
session.execute(
|
||||||
tenant_id=tenant_id,
|
delete(ToolOAuthTenantClient)
|
||||||
provider=tool_provider.provider_name,
|
.where(
|
||||||
plugin_id=tool_provider.plugin_id,
|
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||||
).delete()
|
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||||
|
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||||
|
)
|
||||||
|
.execution_options(synchronize_session=False)
|
||||||
|
)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -151,13 +155,13 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
# get if the provider exists
|
# get if the provider exists
|
||||||
db_provider = (
|
db_provider = session.scalar(
|
||||||
session.query(BuiltinToolProvider)
|
select(BuiltinToolProvider)
|
||||||
.where(
|
.where(
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
BuiltinToolProvider.id == credential_id,
|
BuiltinToolProvider.id == credential_id,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if db_provider is None:
|
if db_provider is None:
|
||||||
raise ValueError(f"you have not added provider {provider}")
|
raise ValueError(f"you have not added provider {provider}")
|
||||||
@ -228,7 +232,13 @@ class BuiltinToolManageService:
|
|||||||
raise ValueError(f"provider {provider} does not need credentials")
|
raise ValueError(f"provider {provider} does not need credentials")
|
||||||
|
|
||||||
provider_count = (
|
provider_count = (
|
||||||
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
session.scalar(
|
||||||
|
select(func.count(BuiltinToolProvider.id)).where(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
BuiltinToolProvider.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
or 0
|
||||||
)
|
)
|
||||||
|
|
||||||
# check if the provider count is reached the limit
|
# check if the provider count is reached the limit
|
||||||
@ -304,16 +314,15 @@ class BuiltinToolManageService:
|
|||||||
def generate_builtin_tool_provider_name(
|
def generate_builtin_tool_provider_name(
|
||||||
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
|
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
|
||||||
) -> str:
|
) -> str:
|
||||||
db_providers = (
|
db_providers = session.scalars(
|
||||||
session.query(BuiltinToolProvider)
|
select(BuiltinToolProvider)
|
||||||
.filter_by(
|
.where(
|
||||||
tenant_id=tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
provider=provider,
|
BuiltinToolProvider.provider == provider,
|
||||||
credential_type=credential_type,
|
BuiltinToolProvider.credential_type == credential_type,
|
||||||
)
|
)
|
||||||
.order_by(BuiltinToolProvider.created_at.desc())
|
.order_by(BuiltinToolProvider.created_at.desc())
|
||||||
.all()
|
).all()
|
||||||
)
|
|
||||||
return generate_incremental_name(
|
return generate_incremental_name(
|
||||||
[provider.name for provider in db_providers],
|
[provider.name for provider in db_providers],
|
||||||
f"{credential_type.get_name()}",
|
f"{credential_type.get_name()}",
|
||||||
@ -375,13 +384,13 @@ class BuiltinToolManageService:
|
|||||||
delete tool provider
|
delete tool provider
|
||||||
"""
|
"""
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
db_provider = (
|
db_provider = session.scalar(
|
||||||
session.query(BuiltinToolProvider)
|
select(BuiltinToolProvider)
|
||||||
.where(
|
.where(
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
BuiltinToolProvider.id == credential_id,
|
BuiltinToolProvider.id == credential_id,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if db_provider is None:
|
if db_provider is None:
|
||||||
@ -405,14 +414,26 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
# get provider
|
# get provider
|
||||||
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
|
target_provider = session.scalar(
|
||||||
|
select(BuiltinToolProvider)
|
||||||
|
.where(BuiltinToolProvider.id == id, BuiltinToolProvider.tenant_id == tenant_id)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
if target_provider is None:
|
if target_provider is None:
|
||||||
raise ValueError("provider not found")
|
raise ValueError("provider not found")
|
||||||
|
|
||||||
# clear default provider
|
# clear default provider
|
||||||
session.query(BuiltinToolProvider).filter_by(
|
session.execute(
|
||||||
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
|
update(BuiltinToolProvider)
|
||||||
).update({"is_default": False})
|
.where(
|
||||||
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
|
BuiltinToolProvider.user_id == user_id,
|
||||||
|
BuiltinToolProvider.provider == provider,
|
||||||
|
BuiltinToolProvider.is_default.is_(True),
|
||||||
|
)
|
||||||
|
.values(is_default=False)
|
||||||
|
.execution_options(synchronize_session=False)
|
||||||
|
)
|
||||||
|
|
||||||
# set new default provider
|
# set new default provider
|
||||||
target_provider.is_default = True
|
target_provider.is_default = True
|
||||||
@ -426,10 +447,13 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
tool_provider = ToolProviderID(provider_name)
|
tool_provider = ToolProviderID(provider_name)
|
||||||
with Session(db.engine, autoflush=False) as session:
|
with Session(db.engine, autoflush=False) as session:
|
||||||
system_client: ToolOAuthSystemClient | None = (
|
system_client = session.scalar(
|
||||||
session.query(ToolOAuthSystemClient)
|
select(ToolOAuthSystemClient)
|
||||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
.where(
|
||||||
.first()
|
ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
|
||||||
|
ToolOAuthSystemClient.provider == tool_provider.provider_name,
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
return system_client is not None
|
return system_client is not None
|
||||||
|
|
||||||
@ -440,15 +464,15 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
tool_provider = ToolProviderID(provider)
|
tool_provider = ToolProviderID(provider)
|
||||||
with Session(db.engine, autoflush=False) as session:
|
with Session(db.engine, autoflush=False) as session:
|
||||||
user_client: ToolOAuthTenantClient | None = (
|
user_client = session.scalar(
|
||||||
session.query(ToolOAuthTenantClient)
|
select(ToolOAuthTenantClient)
|
||||||
.filter_by(
|
.where(
|
||||||
tenant_id=tenant_id,
|
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||||
provider=tool_provider.provider_name,
|
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||||
plugin_id=tool_provider.plugin_id,
|
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||||
enabled=True,
|
ToolOAuthTenantClient.enabled.is_(True),
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
return user_client is not None and user_client.enabled
|
return user_client is not None and user_client.enabled
|
||||||
|
|
||||||
@ -465,15 +489,15 @@ class BuiltinToolManageService:
|
|||||||
cache=NoOpProviderCredentialCache(),
|
cache=NoOpProviderCredentialCache(),
|
||||||
)
|
)
|
||||||
with Session(db.engine, autoflush=False) as session:
|
with Session(db.engine, autoflush=False) as session:
|
||||||
user_client: ToolOAuthTenantClient | None = (
|
user_client = session.scalar(
|
||||||
session.query(ToolOAuthTenantClient)
|
select(ToolOAuthTenantClient)
|
||||||
.filter_by(
|
.where(
|
||||||
tenant_id=tenant_id,
|
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||||
provider=tool_provider.provider_name,
|
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||||
plugin_id=tool_provider.plugin_id,
|
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||||
enabled=True,
|
ToolOAuthTenantClient.enabled.is_(True),
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
oauth_params: Mapping[str, Any] | None = None
|
oauth_params: Mapping[str, Any] | None = None
|
||||||
if user_client:
|
if user_client:
|
||||||
@ -487,10 +511,13 @@ class BuiltinToolManageService:
|
|||||||
if not is_verified:
|
if not is_verified:
|
||||||
return oauth_params
|
return oauth_params
|
||||||
|
|
||||||
system_client: ToolOAuthSystemClient | None = (
|
system_client = session.scalar(
|
||||||
session.query(ToolOAuthSystemClient)
|
select(ToolOAuthSystemClient)
|
||||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
.where(
|
||||||
.first()
|
ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id,
|
||||||
|
ToolOAuthSystemClient.provider == tool_provider.provider_name,
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
if system_client:
|
if system_client:
|
||||||
try:
|
try:
|
||||||
@ -582,8 +609,8 @@ class BuiltinToolManageService:
|
|||||||
provider_name = provider_id_entity.provider_name
|
provider_name = provider_id_entity.provider_name
|
||||||
|
|
||||||
if provider_id_entity.organization != "langgenius":
|
if provider_id_entity.organization != "langgenius":
|
||||||
provider = (
|
provider = session.scalar(
|
||||||
session.query(BuiltinToolProvider)
|
select(BuiltinToolProvider)
|
||||||
.where(
|
.where(
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
BuiltinToolProvider.provider == full_provider_name,
|
BuiltinToolProvider.provider == full_provider_name,
|
||||||
@ -592,11 +619,11 @@ class BuiltinToolManageService:
|
|||||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
provider = (
|
provider = session.scalar(
|
||||||
session.query(BuiltinToolProvider)
|
select(BuiltinToolProvider)
|
||||||
.where(
|
.where(
|
||||||
BuiltinToolProvider.tenant_id == tenant_id,
|
BuiltinToolProvider.tenant_id == tenant_id,
|
||||||
(BuiltinToolProvider.provider == provider_name)
|
(BuiltinToolProvider.provider == provider_name)
|
||||||
@ -606,7 +633,7 @@ class BuiltinToolManageService:
|
|||||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider is None:
|
if provider is None:
|
||||||
@ -616,14 +643,14 @@ class BuiltinToolManageService:
|
|||||||
return provider
|
return provider
|
||||||
except Exception:
|
except Exception:
|
||||||
# it's an old provider without organization
|
# it's an old provider without organization
|
||||||
return (
|
return session.scalar(
|
||||||
session.query(BuiltinToolProvider)
|
select(BuiltinToolProvider)
|
||||||
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
||||||
.order_by(
|
.order_by(
|
||||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -648,14 +675,14 @@ class BuiltinToolManageService:
|
|||||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||||
|
|
||||||
with sessionmaker(bind=db.engine).begin() as session:
|
with sessionmaker(bind=db.engine).begin() as session:
|
||||||
custom_client_params = (
|
custom_client_params = session.scalar(
|
||||||
session.query(ToolOAuthTenantClient)
|
select(ToolOAuthTenantClient)
|
||||||
.filter_by(
|
.where(
|
||||||
tenant_id=tenant_id,
|
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||||
plugin_id=tool_provider.plugin_id,
|
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||||
provider=tool_provider.provider_name,
|
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# if the record does not exist, create a basic record
|
# if the record does not exist, create a basic record
|
||||||
@ -692,14 +719,14 @@ class BuiltinToolManageService:
|
|||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
tool_provider = ToolProviderID(provider)
|
tool_provider = ToolProviderID(provider)
|
||||||
custom_oauth_client_params: ToolOAuthTenantClient | None = (
|
custom_oauth_client_params = session.scalar(
|
||||||
session.query(ToolOAuthTenantClient)
|
select(ToolOAuthTenantClient)
|
||||||
.filter_by(
|
.where(
|
||||||
tenant_id=tenant_id,
|
ToolOAuthTenantClient.tenant_id == tenant_id,
|
||||||
plugin_id=tool_provider.plugin_id,
|
ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id,
|
||||||
provider=tool_provider.provider_name,
|
ToolOAuthTenantClient.provider == tool_provider.provider_name,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if custom_oauth_client_params is None:
|
if custom_oauth_client_params is None:
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class TestDeleteCustomOauthClientParams:
|
|||||||
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google")
|
||||||
|
|
||||||
assert result == {"result": "success"}
|
assert result == {"result": "success"}
|
||||||
session.query.return_value.filter_by.return_value.delete.assert_called_once()
|
session.execute.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestListBuiltinToolProviderTools:
|
class TestListBuiltinToolProviderTools:
|
||||||
@ -111,7 +111,7 @@ class TestIsOauthSystemClientExists:
|
|||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_true_when_exists(self, mock_db, mock_session_cls):
|
def test_true_when_exists(self, mock_db, mock_session_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_session(mock_session_cls)
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = MagicMock()
|
session.scalar.return_value = MagicMock()
|
||||||
|
|
||||||
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True
|
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True
|
||||||
|
|
||||||
@ -119,7 +119,7 @@ class TestIsOauthSystemClientExists:
|
|||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_false_when_missing(self, mock_db, mock_session_cls):
|
def test_false_when_missing(self, mock_db, mock_session_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_session(mock_session_cls)
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False
|
assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False
|
||||||
|
|
||||||
@ -129,7 +129,7 @@ class TestIsOauthCustomClientEnabled:
|
|||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_true_when_enabled(self, mock_db, mock_session_cls):
|
def test_true_when_enabled(self, mock_db, mock_session_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_session(mock_session_cls)
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True)
|
session.scalar.return_value = MagicMock(enabled=True)
|
||||||
|
|
||||||
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True
|
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True
|
||||||
|
|
||||||
@ -137,7 +137,7 @@ class TestIsOauthCustomClientEnabled:
|
|||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_false_when_none(self, mock_db, mock_session_cls):
|
def test_false_when_none(self, mock_db, mock_session_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_session(mock_session_cls)
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False
|
assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False
|
||||||
|
|
||||||
@ -149,7 +149,7 @@ class TestDeleteBuiltinToolProvider:
|
|||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||||
session = _mock_sessionmaker(mock_sm_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="you have not added provider"):
|
with pytest.raises(ValueError, match="you have not added provider"):
|
||||||
BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
|
BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id")
|
||||||
@ -161,7 +161,7 @@ class TestDeleteBuiltinToolProvider:
|
|||||||
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc):
|
||||||
session = _mock_sessionmaker(mock_sm_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
db_provider = MagicMock()
|
db_provider = MagicMock()
|
||||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
session.scalar.return_value = db_provider
|
||||||
mock_cache = MagicMock()
|
mock_cache = MagicMock()
|
||||||
mock_enc.return_value = (MagicMock(), mock_cache)
|
mock_enc.return_value = (MagicMock(), mock_cache)
|
||||||
|
|
||||||
@ -177,7 +177,7 @@ class TestSetDefaultProvider:
|
|||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
|
def test_raises_when_not_found(self, mock_db, mock_sm_cls):
|
||||||
session = _mock_sessionmaker(mock_sm_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="provider not found"):
|
with pytest.raises(ValueError, match="provider not found"):
|
||||||
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||||
@ -187,7 +187,7 @@ class TestSetDefaultProvider:
|
|||||||
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
|
def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls):
|
||||||
session = _mock_sessionmaker(mock_sm_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
target = MagicMock()
|
target = MagicMock()
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = target
|
session.scalar.return_value = target
|
||||||
|
|
||||||
result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id")
|
||||||
|
|
||||||
@ -200,7 +200,7 @@ class TestUpdateBuiltinToolProvider:
|
|||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
|
def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls):
|
||||||
session = _mock_sessionmaker(mock_sm_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
session.query.return_value.where.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="you have not added provider"):
|
with pytest.raises(ValueError, match="you have not added provider"):
|
||||||
BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
|
BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c")
|
||||||
@ -213,7 +213,7 @@ class TestUpdateBuiltinToolProvider:
|
|||||||
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
|
def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc):
|
||||||
session = _mock_sessionmaker(mock_sm_cls)
|
session = _mock_sessionmaker(mock_sm_cls)
|
||||||
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
db_provider = MagicMock(credential_type="api_key", credentials="{}")
|
||||||
session.query.return_value.where.return_value.first.return_value = db_provider
|
session.scalar.return_value = db_provider
|
||||||
|
|
||||||
mock_cred_instance = MagicMock()
|
mock_cred_instance = MagicMock()
|
||||||
mock_cred_instance.is_editable.return_value = True
|
mock_cred_instance.is_editable.return_value = True
|
||||||
@ -274,7 +274,7 @@ class TestGetOauthClient:
|
|||||||
mock_create_enc.return_value = (mock_encrypter, MagicMock())
|
mock_create_enc.return_value = (mock_encrypter, MagicMock())
|
||||||
|
|
||||||
user_client = MagicMock(oauth_params='{"encrypted": "data"}')
|
user_client = MagicMock(oauth_params='{"encrypted": "data"}')
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = user_client
|
session.scalar.return_value = user_client
|
||||||
|
|
||||||
result = BuiltinToolManageService.get_oauth_client("t", "google")
|
result = BuiltinToolManageService.get_oauth_client("t", "google")
|
||||||
|
|
||||||
@ -297,10 +297,7 @@ class TestGetOauthClient:
|
|||||||
mock_create_enc.return_value = (MagicMock(), MagicMock())
|
mock_create_enc.return_value = (MagicMock(), MagicMock())
|
||||||
|
|
||||||
system_client = MagicMock(encrypted_oauth_params="enc")
|
system_client = MagicMock(encrypted_oauth_params="enc")
|
||||||
session.query.return_value.filter_by.return_value.first.side_effect = [
|
session.scalar.side_effect = [None, system_client]
|
||||||
None, # user client
|
|
||||||
system_client, # system client
|
|
||||||
]
|
|
||||||
|
|
||||||
result = BuiltinToolManageService.get_oauth_client("t", "google")
|
result = BuiltinToolManageService.get_oauth_client("t", "google")
|
||||||
|
|
||||||
@ -325,7 +322,7 @@ class TestGetCustomOauthClientParams:
|
|||||||
@patch(f"{MODULE}.db")
|
@patch(f"{MODULE}.db")
|
||||||
def test_returns_empty_when_none(self, mock_db, mock_session_cls):
|
def test_returns_empty_when_none(self, mock_db, mock_session_cls):
|
||||||
session = _mock_session(mock_session_cls)
|
session = _mock_session(mock_session_cls)
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")
|
result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p")
|
||||||
|
|
||||||
@ -391,7 +388,7 @@ class TestGetBuiltinProvider:
|
|||||||
session = _mock_session(mock_session_cls)
|
session = _mock_session(mock_session_cls)
|
||||||
mock_prov_id.return_value.provider_name = "google"
|
mock_prov_id.return_value.provider_name = "google"
|
||||||
mock_prov_id.return_value.organization = "langgenius"
|
mock_prov_id.return_value.organization = "langgenius"
|
||||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
session.scalar.return_value = None
|
||||||
|
|
||||||
result = BuiltinToolManageService.get_builtin_provider("google", "t")
|
result = BuiltinToolManageService.get_builtin_provider("google", "t")
|
||||||
|
|
||||||
@ -417,7 +414,7 @@ class TestGetBuiltinProvider:
|
|||||||
return m
|
return m
|
||||||
|
|
||||||
mock_prov_id.side_effect = prov_id_side_effect
|
mock_prov_id.side_effect = prov_id_side_effect
|
||||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
|
session.scalar.return_value = db_provider
|
||||||
|
|
||||||
result = BuiltinToolManageService.get_builtin_provider("google", "t")
|
result = BuiltinToolManageService.get_builtin_provider("google", "t")
|
||||||
|
|
||||||
@ -439,7 +436,7 @@ class TestGetBuiltinProvider:
|
|||||||
|
|
||||||
mock_prov_id.side_effect = prov_id_side_effect
|
mock_prov_id.side_effect = prov_id_side_effect
|
||||||
db_provider = MagicMock(provider="third-party/custom/custom-tool")
|
db_provider = MagicMock(provider="third-party/custom/custom-tool")
|
||||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider
|
session.scalar.return_value = db_provider
|
||||||
|
|
||||||
result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")
|
result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t")
|
||||||
|
|
||||||
@ -452,7 +449,7 @@ class TestGetBuiltinProvider:
|
|||||||
session = _mock_session(mock_session_cls)
|
session = _mock_session(mock_session_cls)
|
||||||
mock_prov_id.side_effect = Exception("parse error")
|
mock_prov_id.side_effect = Exception("parse error")
|
||||||
fallback = MagicMock()
|
fallback = MagicMock()
|
||||||
session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback
|
session.scalar.return_value = fallback
|
||||||
|
|
||||||
result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")
|
result = BuiltinToolManageService.get_builtin_provider("old-provider", "t")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user