mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 13:51:05 +08:00
test: unit test case for controllers.console.workspace module (#32181)
This commit is contained in:
parent
8906ab8e52
commit
497feac48e
@ -0,0 +1,341 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailCodeError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError
|
||||
from controllers.console.workspace.account import (
|
||||
AccountAvatarApi,
|
||||
AccountDeleteApi,
|
||||
AccountDeleteVerifyApi,
|
||||
AccountInitApi,
|
||||
AccountIntegrateApi,
|
||||
AccountInterfaceLanguageApi,
|
||||
AccountInterfaceThemeApi,
|
||||
AccountNameApi,
|
||||
AccountPasswordApi,
|
||||
AccountProfileApi,
|
||||
AccountTimezoneApi,
|
||||
ChangeEmailCheckApi,
|
||||
ChangeEmailResetApi,
|
||||
CheckEmailUnique,
|
||||
)
|
||||
from controllers.console.workspace.error import (
|
||||
AccountAlreadyInitedError,
|
||||
CurrentPasswordIncorrectError,
|
||||
InvalidAccountDeletionCodeError,
|
||||
)
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestAccountInitApi:
|
||||
def test_init_success(self, app):
|
||||
api = AccountInitApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
account = MagicMock(status="inactive")
|
||||
payload = {
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
"invitation_code": "code123",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/init", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
|
||||
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.workspace.account.db.session.query") as query_mock,
|
||||
):
|
||||
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
|
||||
resp = method(api)
|
||||
|
||||
assert resp["result"] == "success"
|
||||
|
||||
def test_init_already_initialized(self, app):
|
||||
api = AccountInitApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
account = MagicMock(status="active")
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/init"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
):
|
||||
with pytest.raises(AccountAlreadyInitedError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAccountProfileApi:
|
||||
def test_get_profile_success(self, app):
|
||||
api = AccountProfileApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/profile"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
|
||||
class TestAccountUpdateApis:
|
||||
@pytest.mark.parametrize(
|
||||
("api_cls", "payload"),
|
||||
[
|
||||
(AccountNameApi, {"name": "test"}),
|
||||
(AccountAvatarApi, {"avatar": "img.png"}),
|
||||
(AccountInterfaceLanguageApi, {"interface_language": "en-US"}),
|
||||
(AccountInterfaceThemeApi, {"interface_theme": "dark"}),
|
||||
(AccountTimezoneApi, {"timezone": "UTC"}),
|
||||
],
|
||||
)
|
||||
def test_update_success(self, app, api_cls, payload):
|
||||
api = api_cls()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
|
||||
class TestAccountPasswordApi:
|
||||
def test_password_success(self, app):
|
||||
api = AccountPasswordApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"password": "old",
|
||||
"new_password": "new123",
|
||||
"repeat_new_password": "new123",
|
||||
}
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
def test_password_wrong_current(self, app):
|
||||
api = AccountPasswordApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"password": "bad",
|
||||
"new_password": "new123",
|
||||
"repeat_new_password": "new123",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.update_account_password",
|
||||
side_effect=ServicePwdError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CurrentPasswordIncorrectError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAccountIntegrateApi:
|
||||
def test_get_integrates(self, app):
|
||||
api = AccountIntegrateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
account = MagicMock(id="acc1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
|
||||
):
|
||||
scalars_mock.return_value.all.return_value = []
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
|
||||
class TestAccountDeleteApi:
|
||||
def test_delete_verify_success(self, app):
|
||||
api = AccountDeleteVerifyApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
|
||||
return_value=("token", "1234"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.send_account_deletion_verification_email",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_invalid_code(self, app):
|
||||
api = AccountDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"token": "t", "code": "x"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidAccountDeletionCodeError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestChangeEmailApis:
|
||||
def test_check_email_code_invalid(self, app):
|
||||
api = ChangeEmailCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "a@test.com", "code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.get_change_email_data",
|
||||
return_value={"email": "a@test.com", "code": "y"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(EmailCodeError):
|
||||
method(api)
|
||||
|
||||
def test_reset_email_already_used(self, app):
|
||||
api = ChangeEmailResetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"new_email": "x@test.com", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
|
||||
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
|
||||
):
|
||||
with pytest.raises(EmailAlreadyInUseError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCheckEmailUniqueApi:
|
||||
def test_email_unique_success(self, app):
|
||||
api = CheckEmailUnique()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "ok@test.com"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
|
||||
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_email_in_freeze(self, app):
|
||||
api = CheckEmailUnique()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "x@test.com"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True),
|
||||
):
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
method(api)
|
||||
@ -0,0 +1,139 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.error import AccountNotFound
|
||||
from controllers.console.workspace.agent_providers import (
|
||||
AgentProviderApi,
|
||||
AgentProviderListApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestAgentProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
providers = [{"name": "openai"}, {"name": "anthropic"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
|
||||
return_value=providers,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == providers
|
||||
|
||||
def test_get_empty_list(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_get_account_not_found(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
side_effect=AccountNotFound(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AccountNotFound):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAgentProviderApi:
|
||||
def test_get_success(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
provider_name = "openai"
|
||||
provider_data = {"name": "openai", "models": ["gpt-4"]}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
|
||||
return_value=provider_data,
|
||||
),
|
||||
):
|
||||
result = method(api, provider_name)
|
||||
|
||||
assert result == provider_data
|
||||
|
||||
def test_get_provider_not_found(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
provider_name = "unknown"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider_name)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_account_not_found(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
side_effect=AccountNotFound(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AccountNotFound):
|
||||
method(api, "openai")
|
||||
@ -0,0 +1,305 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.workspace.endpoint import (
|
||||
EndpointCreateApi,
|
||||
EndpointDeleteApi,
|
||||
EndpointDisableApi,
|
||||
EndpointEnableApi,
|
||||
EndpointListApi,
|
||||
EndpointListForSinglePluginApi,
|
||||
EndpointUpdateApi,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_and_tenant():
|
||||
return MagicMock(id="u1"), "t1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_current_account(user_and_tenant):
|
||||
with patch(
|
||||
"controllers.console.workspace.endpoint.current_account_with_tenant",
|
||||
return_value=user_and_tenant,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
"name": "endpoint",
|
||||
"settings": {"a": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_create_permission_denied(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
"name": "endpoint",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.endpoint.EndpointService.create_endpoint",
|
||||
side_effect=PluginPermissionDeniedError("denied"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_create_validation_error(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "p1",
|
||||
"name": "",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListApi:
|
||||
def test_list_success(self, app):
|
||||
api = EndpointListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "endpoints" in result
|
||||
assert len(result["endpoints"]) == 1
|
||||
|
||||
def test_list_invalid_query(self, app):
|
||||
api = EndpointListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=0&page_size=10"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListForSinglePluginApi:
|
||||
def test_list_for_plugin_success(self, app):
|
||||
api = EndpointListForSinglePluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
|
||||
patch(
|
||||
"controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin",
|
||||
return_value=[{"id": "e1"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "endpoints" in result
|
||||
|
||||
def test_list_for_plugin_missing_param(self, app):
|
||||
api = EndpointListForSinglePluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_delete_invalid_payload(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_delete_service_failure(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"endpoint_id": "e1",
|
||||
"name": "new-name",
|
||||
"settings": {"x": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_update_validation_error(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1", "settings": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_update_service_failure(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"endpoint_id": "e1",
|
||||
"name": "n",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointEnableApi:
|
||||
def test_enable_success(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_enable_invalid_payload(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_enable_service_failure(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointDisableApi:
|
||||
def test_disable_success(self, app):
|
||||
api = EndpointDisableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_disable_invalid_payload(self, app):
|
||||
api = EndpointDisableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
@ -0,0 +1,607 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
EmailCodeError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
MemberNotInTenantError,
|
||||
NotOwnerError,
|
||||
OwnerTransferLimitError,
|
||||
)
|
||||
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
|
||||
from controllers.console.workspace.members import (
|
||||
DatasetOperatorMemberListApi,
|
||||
MemberCancelInviteApi,
|
||||
MemberInviteEmailApi,
|
||||
MemberListApi,
|
||||
MemberUpdateRoleApi,
|
||||
OwnerTransfer,
|
||||
OwnerTransferCheckApi,
|
||||
SendOwnerTransferEmailApi,
|
||||
)
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestMemberListApi:
|
||||
def test_get_success(self, app):
|
||||
api = MemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
member.id = "m1"
|
||||
member.name = "Member"
|
||||
member.email = "member@test.com"
|
||||
member.avatar = "avatar.png"
|
||||
member.role = "admin"
|
||||
member.status = "active"
|
||||
members = [member]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["accounts"]) == 1
|
||||
|
||||
def test_get_no_tenant(self, app):
|
||||
api = MemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(current_tenant=None)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestMemberInviteEmailApi:
|
||||
def test_invite_success(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
"language": "en-US",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_invite_limit_exceeded(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = False
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
with pytest.raises(WorkspaceMembersLimitExceeded):
|
||||
method(api)
|
||||
|
||||
def test_invite_already_member(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch(
|
||||
"controllers.console.workspace.members.RegisterService.invite_new_member",
|
||||
side_effect=AccountAlreadyInTenantError(),
|
||||
),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert result["invitation_results"][0]["status"] == "success"
|
||||
|
||||
def test_invite_invalid_role(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "owner",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 400
|
||||
assert result["code"] == "invalid-role"
|
||||
|
||||
def test_invite_generic_exception(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch(
|
||||
"controllers.console.workspace.members.RegisterService.invite_new_member",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, _ = method(api)
|
||||
|
||||
assert result["invitation_results"][0]["status"] == "failed"
|
||||
|
||||
|
||||
class TestMemberCancelInviteApi:
|
||||
def test_cancel_success(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_cancel_not_found(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "x")
|
||||
|
||||
def test_cancel_cannot_operate_self(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.CannotOperateSelfError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 400
|
||||
|
||||
def test_cancel_no_permission(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.NoPermissionError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 403
|
||||
|
||||
def test_cancel_member_not_in_tenant(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.MemberNotInTenantError(),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 404
|
||||
|
||||
|
||||
class TestMemberUpdateRoleApi:
|
||||
def test_update_success(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
payload = {"role": "normal"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=member),
|
||||
patch("controllers.console.workspace.members.TenantService.update_member_role"),
|
||||
):
|
||||
result = method(api, "id")
|
||||
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_update_invalid_role(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"role": "invalid-role"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
result, status = method(api, "id")
|
||||
|
||||
assert status == 400
|
||||
|
||||
def test_update_member_not_found(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"role": "normal"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.members.current_account_with_tenant",
|
||||
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=None),
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "id")
|
||||
|
||||
|
||||
class TestDatasetOperatorMemberListApi:
|
||||
def test_get_success(self, app):
|
||||
api = DatasetOperatorMemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
member.id = "op1"
|
||||
member.name = "Operator"
|
||||
member.email = "operator@test.com"
|
||||
member.avatar = "avatar.png"
|
||||
member.role = "operator"
|
||||
member.status = "active"
|
||||
members = [member]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["accounts"]) == 1
|
||||
|
||||
def test_get_no_tenant(self, app):
|
||||
api = DatasetOperatorMemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(current_tenant=None)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestSendOwnerTransferEmailApi:
|
||||
def test_send_success(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(name="ws")
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_send_ip_limit(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
|
||||
):
|
||||
with pytest.raises(EmailSendIpLimitError):
|
||||
method(api)
|
||||
|
||||
def test_send_not_owner(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
|
||||
):
|
||||
with pytest.raises(NotOwnerError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestOwnerTransferCheckApi:
|
||||
def test_check_invalid_code(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "a@test.com", "code": "y"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(EmailCodeError):
|
||||
method(api)
|
||||
|
||||
def test_rate_limited(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
with pytest.raises(OwnerTransferLimitError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_token(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_email(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "b@test.com", "code": "x"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidEmailError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestOwnerTransferApi:
|
||||
def test_transfer_self(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
):
|
||||
with pytest.raises(CannotTransferOwnerToSelfError):
|
||||
method(api, "1")
|
||||
|
||||
def test_invalid_token(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api, "2")
|
||||
|
||||
def test_member_not_in_tenant(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "a@test.com"},
|
||||
),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=member),
|
||||
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
|
||||
):
|
||||
with pytest.raises(MemberNotInTenantError):
|
||||
method(api, "2")
|
||||
@ -0,0 +1,388 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic_core import ValidationError
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.model_providers import (
|
||||
ModelProviderCredentialApi,
|
||||
ModelProviderCredentialSwitchApi,
|
||||
ModelProviderIconApi,
|
||||
ModelProviderListApi,
|
||||
ModelProviderPaymentCheckoutUrlApi,
|
||||
ModelProviderValidateApi,
|
||||
PreferredProviderTypeUpdateApi,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
|
||||
INVALID_UUID = "123"
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestModelProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ModelProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?model_type=llm"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_list",
|
||||
return_value=[{"name": "openai"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderCredentialApi:
|
||||
def test_get_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/?credential_id={VALID_UUID}"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential",
|
||||
return_value={"key": "value"},
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
def test_get_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/?credential_id={INVALID_UUID}"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_post_create_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}, "name": "test"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result, status = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert status == 201
|
||||
|
||||
def test_post_create_validation_error(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
|
||||
side_effect=CredentialsValidateFailedError("bad"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_put_update_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_put_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {"credential_id": VALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result, status = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestModelProviderCredentialSwitchApi:
|
||||
def test_switch_success(self, app):
|
||||
api = ModelProviderCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": VALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_switch_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": INVALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderValidateApi:
|
||||
def test_validate_success(self, app):
|
||||
api = ModelProviderValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_validate_failure(self, app):
|
||||
api = ModelProviderValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
|
||||
side_effect=CredentialsValidateFailedError("bad"),
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "error"
|
||||
|
||||
|
||||
class TestModelProviderIconApi:
|
||||
def test_icon_success(self, app):
|
||||
api = ModelProviderIconApi()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
|
||||
return_value=(b"123", "image/png"),
|
||||
),
|
||||
):
|
||||
response = api.get("t1", "openai", "logo", "en")
|
||||
|
||||
assert response.mimetype == "image/png"
|
||||
|
||||
def test_icon_not_found(self, app):
|
||||
api = ModelProviderIconApi()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
|
||||
return_value=(None, None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
api.get("t1", "openai", "logo", "en")
|
||||
|
||||
|
||||
class TestPreferredProviderTypeUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = PreferredProviderTypeUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"preferred_provider_type": "custom"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_invalid_enum(self, app):
|
||||
api = PreferredProviderTypeUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"preferred_provider_type": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderPaymentCheckoutUrlApi:
|
||||
def test_checkout_success(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="u1", email="x@test.com")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(user, "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link",
|
||||
return_value={"url": "x"},
|
||||
),
|
||||
):
|
||||
result = method(api, provider="anthropic")
|
||||
|
||||
assert "url" in result
|
||||
|
||||
def test_invalid_provider(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_permission_denied(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="u1", email="x@test.com")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(user, "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
|
||||
side_effect=Forbidden(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, provider="anthropic")
|
||||
@ -0,0 +1,447 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.workspace.models import (
|
||||
DefaultModelApi,
|
||||
ModelProviderAvailableModelApi,
|
||||
ModelProviderModelApi,
|
||||
ModelProviderModelCredentialApi,
|
||||
ModelProviderModelCredentialSwitchApi,
|
||||
ModelProviderModelDisableApi,
|
||||
ModelProviderModelEnableApi,
|
||||
ModelProviderModelParameterRuleApi,
|
||||
ModelProviderModelValidateApi,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDefaultModelApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"model_type": ModelType.LLM.value},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"}
|
||||
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_post_success(self, app: Flask):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model_settings": [
|
||||
{
|
||||
"model_type": ModelType.LLM.value,
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_get_returns_empty_when_no_default(self, app):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_default_model_of_model_type.return_value = None
|
||||
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderModelApi:
|
||||
def test_get_models_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_models_by_provider.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_post_models_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"load_balancing": {
|
||||
"configs": [{"weight": 1}],
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_model_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 204
|
||||
|
||||
def test_get_models_returns_empty(self, app):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_models_by_provider.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderModelCredentialApi:
|
||||
def test_get_credentials_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as provider_service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service,
|
||||
):
|
||||
provider_service.return_value.get_model_credential.return_value = {
|
||||
"credentials": {},
|
||||
"current_credential_id": None,
|
||||
"current_credential_name": None,
|
||||
}
|
||||
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
def test_create_credential_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {"key": "val"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_get_empty_credentials(self, app):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
|
||||
):
|
||||
service.return_value.get_model_credential.return_value = None
|
||||
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["credentials"] == {}
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {
|
||||
"model": "gpt",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credential_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestModelProviderModelCredentialSwitchApi:
|
||||
def test_switch_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credential_id": "abc",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestModelEnableDisableApis:
|
||||
def test_enable_model(self, app: Flask):
|
||||
api = ModelProviderModelEnableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_disable_model(self, app: Flask):
|
||||
api = ModelProviderModelDisableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestModelProviderModelValidateApi:
|
||||
def test_validate_success(self, app: Flask):
|
||||
api = ModelProviderModelValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {"key": "val"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["gpt-4", "gpt"])
|
||||
def test_validate_failure(self, app: Flask, model_name: str):
|
||||
api = ModelProviderModelValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid")
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "error"
|
||||
|
||||
|
||||
class TestParameterAndAvailableModels:
|
||||
def test_parameter_rules(self, app: Flask):
|
||||
api = ModelProviderModelParameterRuleApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt-4"}),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_model_parameter_rules.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_available_models(self, app: Flask):
|
||||
api = ModelProviderAvailableModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM.value)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_empty_rules(self, app):
|
||||
api = ModelProviderModelParameterRuleApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt"}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_model_parameter_rules.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["data"] == []
|
||||
|
||||
def test_no_models(self, app):
|
||||
api = ModelProviderAvailableModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM.value)
|
||||
|
||||
assert result["data"] == []
|
||||
1019
api/tests/unit_tests/controllers/console/workspace/test_plugin.py
Normal file
1019
api/tests/unit_tests/controllers/console/workspace/test_plugin.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -4,16 +4,52 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||
from controllers.console.workspace.tool_providers import (
|
||||
ToolApiListApi,
|
||||
ToolApiProviderAddApi,
|
||||
ToolApiProviderDeleteApi,
|
||||
ToolApiProviderGetApi,
|
||||
ToolApiProviderGetRemoteSchemaApi,
|
||||
ToolApiProviderListToolsApi,
|
||||
ToolApiProviderUpdateApi,
|
||||
ToolBuiltinListApi,
|
||||
ToolBuiltinProviderAddApi,
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
ToolBuiltinProviderDeleteApi,
|
||||
ToolBuiltinProviderGetCredentialInfoApi,
|
||||
ToolBuiltinProviderGetCredentialsApi,
|
||||
ToolBuiltinProviderGetOauthClientSchemaApi,
|
||||
ToolBuiltinProviderIconApi,
|
||||
ToolBuiltinProviderInfoApi,
|
||||
ToolBuiltinProviderListToolsApi,
|
||||
ToolBuiltinProviderSetDefaultApi,
|
||||
ToolBuiltinProviderUpdateApi,
|
||||
ToolLabelsApi,
|
||||
ToolOAuthCallback,
|
||||
ToolOAuthCustomClient,
|
||||
ToolPluginOAuthApi,
|
||||
ToolProviderListApi,
|
||||
ToolProviderMCPApi,
|
||||
ToolWorkflowListApi,
|
||||
ToolWorkflowProviderCreateApi,
|
||||
ToolWorkflowProviderDeleteApi,
|
||||
ToolWorkflowProviderGetApi,
|
||||
ToolWorkflowProviderUpdateApi,
|
||||
is_valid_url,
|
||||
)
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||
# They are intentionally no-ops because the test already patches the required
|
||||
# behaviors explicitly via @patch and context managers below.
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
return
|
||||
@ -107,3 +143,602 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
|
||||
|
||||
class TestUtils:
|
||||
def test_is_valid_url(self):
|
||||
assert is_valid_url("https://example.com")
|
||||
assert is_valid_url("http://example.com")
|
||||
assert not is_valid_url("")
|
||||
assert not is_valid_url("ftp://example.com")
|
||||
assert not is_valid_url("not-a-url")
|
||||
assert not is_valid_url(None)
|
||||
|
||||
|
||||
class TestToolProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ToolProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
|
||||
return_value=["p1"],
|
||||
),
|
||||
):
|
||||
assert method(api) == ["p1"]
|
||||
|
||||
|
||||
class TestBuiltinProviderApis:
|
||||
def test_list_tools(self, app):
|
||||
api = ToolBuiltinProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
|
||||
return_value=[{"a": 1}],
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == [{"a": 1}]
|
||||
|
||||
def test_info(self, app):
|
||||
api = ToolBuiltinProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
|
||||
return_value={"x": 1},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"x": 1}
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolBuiltinProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credential_id": "cid"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider",
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["result"] == "success"
|
||||
|
||||
def test_add_invalid_type(self, app):
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}, "type": "invalid"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "provider")
|
||||
|
||||
def test_add_success(self, app):
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {}, "type": "oauth2", "name": "n"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["id"] == 1
|
||||
|
||||
def test_update(self, app):
|
||||
api = ToolBuiltinProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "c1", "credentials": {}, "name": "n"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_credentials(self, app):
|
||||
api = ToolBuiltinProviderGetCredentialsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
|
||||
return_value={"k": "v"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"k": "v"}
|
||||
|
||||
def test_icon(self, app):
|
||||
api = ToolBuiltinProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon",
|
||||
return_value=(b"x", "image/png"),
|
||||
),
|
||||
):
|
||||
response = method(api, "provider")
|
||||
assert response.mimetype == "image/png"
|
||||
|
||||
def test_credentials_schema(self, app):
|
||||
api = ToolBuiltinProviderCredentialsSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
|
||||
return_value={"schema": {}},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider", "oauth2") == {"schema": {}}
|
||||
|
||||
def test_set_default_credential(self, app):
|
||||
api = ToolBuiltinProviderSetDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"id": "c1"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_credential_info(self, app):
|
||||
api = ToolBuiltinProviderGetCredentialInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
|
||||
return_value={"info": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"info": "x"}
|
||||
|
||||
def test_get_oauth_client_schema(self, app):
|
||||
api = ToolBuiltinProviderGetOauthClientSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
|
||||
return_value={"schema": {}},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"schema": {}}
|
||||
|
||||
|
||||
class TestApiProviderApis:
|
||||
def test_add(self, app):
|
||||
api = ToolApiProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credentials": {},
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"icon": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
|
||||
def test_remote_schema(self, app):
|
||||
api = ToolApiProviderGetRemoteSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?url=http://x.com"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
|
||||
return_value={"schema": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api)["schema"] == "x"
|
||||
|
||||
def test_list_tools(self, app):
|
||||
api = ToolApiProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
|
||||
return_value=[{"tool": 1}],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"tool": 1}]
|
||||
|
||||
def test_update(self, app):
|
||||
api = ToolApiProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credentials": {},
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"original_provider": "o",
|
||||
"icon": {},
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolApiProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"provider": "p"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider",
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api)["result"] == "success"
|
||||
|
||||
def test_get(self, app):
|
||||
api = ToolApiProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
|
||||
return_value={"x": 1},
|
||||
),
|
||||
):
|
||||
assert method(api) == {"x": 1}
|
||||
|
||||
|
||||
class TestWorkflowApis:
|
||||
def test_create(self, app):
|
||||
api = ToolWorkflowProviderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"workflow_app_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"name": "n",
|
||||
"label": "l",
|
||||
"description": "d",
|
||||
"icon": {},
|
||||
"parameters": [],
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
|
||||
def test_update_invalid(self, app):
|
||||
api = ToolWorkflowProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"name": "Tool",
|
||||
"label": "Tool Label",
|
||||
"description": "A tool",
|
||||
"icon": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
assert result["ok"]
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolWorkflowProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
|
||||
def test_get_error(self, app):
|
||||
api = ToolWorkflowProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestLists:
|
||||
def test_builtin_list(self, app):
|
||||
api = ToolBuiltinListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
def test_api_list(self, app):
|
||||
api = ToolApiListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
def test_workflow_list(self, app):
|
||||
api = ToolWorkflowListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
|
||||
class TestLabels:
|
||||
def test_labels(self, app):
|
||||
api = ToolLabelsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels",
|
||||
return_value=["l1"],
|
||||
),
|
||||
):
|
||||
assert method(api) == ["l1"]
|
||||
|
||||
|
||||
class TestOAuth:
|
||||
def test_oauth_no_client(self, app):
|
||||
api = ToolPluginOAuthApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "provider")
|
||||
|
||||
def test_oauth_callback_no_cookie(self, app):
|
||||
api = ToolOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "provider")
|
||||
|
||||
|
||||
class TestOAuthCustomClient:
|
||||
def test_save_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"client_params": {"a": 1}}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params",
|
||||
return_value={"client_id": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"client_id": "x"}
|
||||
|
||||
def test_delete_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
@ -0,0 +1,558 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from controllers.console.workspace.trigger_providers import (
|
||||
TriggerOAuthAuthorizeApi,
|
||||
TriggerOAuthCallbackApi,
|
||||
TriggerOAuthClientManageApi,
|
||||
TriggerProviderIconApi,
|
||||
TriggerProviderInfoApi,
|
||||
TriggerProviderListApi,
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
TriggerSubscriptionDeleteApi,
|
||||
TriggerSubscriptionListApi,
|
||||
TriggerSubscriptionUpdateApi,
|
||||
TriggerSubscriptionVerifyApi,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def mock_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "u1"
|
||||
user.current_tenant_id = "t1"
|
||||
return user
|
||||
|
||||
|
||||
class TestTriggerProviderApis:
|
||||
def test_icon_success(self, app):
|
||||
api = TriggerProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon",
|
||||
return_value="icon",
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == "icon"
|
||||
|
||||
def test_list_providers(self, app):
|
||||
api = TriggerProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
assert method(api) == []
|
||||
|
||||
def test_provider_info(self, app):
|
||||
api = TriggerProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
|
||||
return_value={"id": "p1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"id": "p1"}
|
||||
|
||||
|
||||
class TestTriggerSubscriptionListApi:
|
||||
def test_list_success(self, app):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == []
|
||||
|
||||
def test_list_invalid_provider(self, app):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, "bad")
|
||||
assert status == 404
|
||||
|
||||
|
||||
class TestTriggerSubscriptionBuilderApis:
|
||||
def test_create_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
result = method(api, "github")
|
||||
assert "subscription_builder" in result
|
||||
|
||||
def test_get_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_verify_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {"a": 1}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"ok": True}
|
||||
|
||||
def test_verify_builder_error(self, app):
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
side_effect=Exception("err"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "github", "b1")
|
||||
|
||||
def test_update_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "n"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_logs(self, app):
|
||||
api = TriggerSubscriptionBuilderLogsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
log = MagicMock()
|
||||
log.model_dump.return_value = {"a": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
|
||||
return_value=[log],
|
||||
),
|
||||
):
|
||||
assert "logs" in method(api, "github", "b1")
|
||||
|
||||
def test_build(self, app):
|
||||
api = TriggerSubscriptionBuilderBuildApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == 200
|
||||
|
||||
|
||||
class TestTriggerSubscriptionCrud:
|
||||
def test_update_rename_only(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
sub = MagicMock()
|
||||
sub.provider_id = "github"
|
||||
sub.credential_type = CredentialType.UNAUTHORIZED
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=sub,
|
||||
),
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
|
||||
def test_update_not_found(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "x")
|
||||
|
||||
def test_update_rebuild(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
sub = MagicMock()
|
||||
sub.provider_id = "github"
|
||||
sub.credential_type = CredentialType.OAUTH2
|
||||
sub.credentials = {}
|
||||
sub.parameters = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=sub,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
|
||||
),
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
|
||||
def test_delete_subscription(self, app):
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
|
||||
patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls,
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription"
|
||||
),
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
result = method(api, "sub1")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_subscription_value_error(self, app):
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
|
||||
patch("controllers.console.workspace.trigger_providers.Session") as session_cls,
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
session_cls.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "sub1")
|
||||
|
||||
|
||||
class TestTriggerOAuthApis:
|
||||
def test_oauth_authorize_success(self, app):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value=MagicMock(id="b1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context",
|
||||
return_value="ctx",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url",
|
||||
return_value=MagicMock(authorization_url="url"),
|
||||
),
|
||||
):
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_oauth_authorize_no_client(self, app):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_forbidden(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_success(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
|
||||
return_value=MagicMock(credentials={"a": 1}, expires_at=1),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder"
|
||||
),
|
||||
):
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 302
|
||||
|
||||
def test_oauth_callback_no_oauth_client(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
|
||||
return_value=ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_empty_credentials(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
|
||||
return_value=ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
|
||||
return_value=MagicMock(credentials=None, expires_at=None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "github")
|
||||
|
||||
|
||||
class TestTriggerOAuthClientManageApi:
|
||||
def test_get_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params",
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider",
|
||||
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
|
||||
),
|
||||
):
|
||||
result = method(api, "github")
|
||||
assert "configured" in result
|
||||
|
||||
def test_post_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
|
||||
def test_delete_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
|
||||
def test_oauth_client_post_value_error(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "github")
|
||||
|
||||
|
||||
class TestTriggerSubscriptionVerifyApi:
|
||||
def test_verify_success(self, app):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "s1") == {"ok": True}
|
||||
|
||||
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
|
||||
def test_verify_errors(self, app, raised_exception):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
side_effect=raised_exception,
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "github", "s1")
|
||||
@ -0,0 +1,605 @@
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.workspace.workspace import (
|
||||
CustomConfigWorkspaceApi,
|
||||
SwitchWorkspaceApi,
|
||||
TenantApi,
|
||||
TenantListApi,
|
||||
WebappLogoWorkspaceApi,
|
||||
WorkspaceInfoApi,
|
||||
WorkspaceListApi,
|
||||
WorkspacePermissionApi,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models.account import TenantStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestTenantListApi:
|
||||
def test_get_success(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = True
|
||||
features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["workspaces"]) == 2
|
||||
assert result["workspaces"][0]["current"] is True
|
||||
|
||||
def test_get_billing_disabled(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.FeatureService.get_features",
|
||||
return_value=features,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
|
||||
|
||||
class TestWorkspaceListApi:
|
||||
def test_get_success(self, app):
|
||||
api = WorkspaceListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow())
|
||||
|
||||
paginate_result = MagicMock(
|
||||
items=[tenant],
|
||||
has_next=False,
|
||||
total=1,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}),
|
||||
patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["total"] == 1
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_has_next_true(self, app):
|
||||
api = WorkspaceListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(
|
||||
id="t1",
|
||||
name="T",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
paginate_result = MagicMock(
|
||||
items=[tenant],
|
||||
has_next=True,
|
||||
total=10,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.db.paginate",
|
||||
return_value=paginate_result,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["has_more"] is True
|
||||
|
||||
|
||||
class TestTenantApi:
|
||||
def test_post_active_tenant(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(status="active")
|
||||
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["id"] == "t1"
|
||||
|
||||
def test_post_archived_with_switch(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
archived = MagicMock(status=TenantStatus.ARCHIVE)
|
||||
new_tenant = MagicMock(status="active")
|
||||
|
||||
user = MagicMock(current_tenant=archived)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"}
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert result["id"] == "new"
|
||||
|
||||
def test_post_archived_no_tenant(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE))
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]),
|
||||
):
|
||||
with pytest.raises(Unauthorized):
|
||||
method(api)
|
||||
|
||||
def test_post_info_path(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(status="active")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/info"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(user, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"id": "t1"},
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.logger.warning") as warn_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
warn_mock.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestSwitchWorkspaceApi:
|
||||
def test_switch_success(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "t2"}
|
||||
tenant = MagicMock(id="t2")
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
|
||||
),
|
||||
):
|
||||
query_mock.return_value.get.return_value = tenant
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_switch_not_linked(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "bad"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception),
|
||||
):
|
||||
with pytest.raises(AccountNotLinkTenantError):
|
||||
method(api)
|
||||
|
||||
def test_switch_tenant_not_found(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "missing"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
):
|
||||
query_mock.return_value.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCustomConfigWorkspaceApi:
|
||||
def test_post_success(self, app):
|
||||
api = CustomConfigWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(custom_config_dict={})
|
||||
|
||||
payload = {"remove_webapp_brand": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/custom-config", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_logo_fallback(self, app):
|
||||
api = CustomConfigWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"})
|
||||
|
||||
payload = {"remove_webapp_brand": False}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/custom-config", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.db.get_or_404",
|
||||
return_value=tenant,
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"id": "t1"},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo"
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestWebappLogoWorkspaceApi:
|
||||
def test_no_file(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data={}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
method(api)
|
||||
|
||||
def test_too_many_files(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": MagicMock(),
|
||||
"extra": MagicMock(),
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data=data),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(TooManyFilesError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_extension(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = MagicMock(filename="test.txt")
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data={"file": file}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
method(api)
|
||||
|
||||
def test_upload_success(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"data"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
upload = MagicMock(id="file1")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.return_value = upload
|
||||
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "file1"
|
||||
|
||||
def test_filename_missing(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"data"),
|
||||
filename="",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
method(api)
|
||||
|
||||
def test_file_too_large(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"x"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big")
|
||||
|
||||
with pytest.raises(FileTooLargeError):
|
||||
method(api)
|
||||
|
||||
def test_service_unsupported_file(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"x"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWorkspaceInfoApi:
|
||||
def test_post_success(self, app):
|
||||
api = WorkspaceInfoApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
|
||||
payload = {"name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/info", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"name": "New Name"},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_no_current_tenant(self, app):
|
||||
api = WorkspaceInfoApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "X"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/info", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWorkspacePermissionApi:
|
||||
def test_get_success(self, app):
|
||||
api = WorkspacePermissionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
permission = MagicMock(
|
||||
workspace_id="t1",
|
||||
allow_member_invite=True,
|
||||
allow_owner_transfer=False,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/permission"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission",
|
||||
return_value=permission,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspace_id"] == "t1"
|
||||
|
||||
def test_no_current_tenant(self, app):
|
||||
api = WorkspacePermissionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/permission"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, permission):
|
||||
self._permission = permission
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def query(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._permission
|
||||
|
||||
|
||||
def _workspace_module():
|
||||
return importlib.import_module(plugin_permission_required.__module__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, None)
|
||||
|
||||
@plugin_permission_required()
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
Loading…
Reference in New Issue
Block a user