test: unit test case for controllers.console.workspace module (#32181)

This commit is contained in:
rajatagarwal-oss 2026-03-09 14:37:40 +05:30 committed by GitHub
parent 8906ab8e52
commit 497feac48e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 5190 additions and 4 deletions

View File

@ -0,0 +1,341 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailCodeError,
)
from controllers.console.error import AccountInFreezeError
from controllers.console.workspace.account import (
AccountAvatarApi,
AccountDeleteApi,
AccountDeleteVerifyApi,
AccountInitApi,
AccountIntegrateApi,
AccountInterfaceLanguageApi,
AccountInterfaceThemeApi,
AccountNameApi,
AccountPasswordApi,
AccountProfileApi,
AccountTimezoneApi,
ChangeEmailCheckApi,
ChangeEmailResetApi,
CheckEmailUnique,
)
from controllers.console.workspace.error import (
AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
InvalidAccountDeletionCodeError,
)
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAccountInitApi:
def test_init_success(self, app):
api = AccountInitApi()
method = unwrap(api.post)
account = MagicMock(status="inactive")
payload = {
"interface_language": "en-US",
"timezone": "UTC",
"invitation_code": "code123",
}
with (
app.test_request_context("/account/init", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
patch("controllers.console.workspace.account.db.session.query") as query_mock,
):
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
resp = method(api)
assert resp["result"] == "success"
def test_init_already_initialized(self, app):
api = AccountInitApi()
method = unwrap(api.post)
account = MagicMock(status="active")
with (
app.test_request_context("/account/init"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
):
with pytest.raises(AccountAlreadyInitedError):
method(api)
class TestAccountProfileApi:
def test_get_profile_success(self, app):
api = AccountProfileApi()
method = unwrap(api.get)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/account/profile"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
):
result = method(api)
assert result["id"] == "u1"
class TestAccountUpdateApis:
@pytest.mark.parametrize(
("api_cls", "payload"),
[
(AccountNameApi, {"name": "test"}),
(AccountAvatarApi, {"avatar": "img.png"}),
(AccountInterfaceLanguageApi, {"interface_language": "en-US"}),
(AccountInterfaceThemeApi, {"interface_theme": "dark"}),
(AccountTimezoneApi, {"timezone": "UTC"}),
],
)
def test_update_success(self, app, api_cls, payload):
api = api_cls()
method = unwrap(api.post)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
):
result = method(api)
assert result["id"] == "u1"
class TestAccountPasswordApi:
def test_password_success(self, app):
api = AccountPasswordApi()
method = unwrap(api.post)
payload = {
"password": "old",
"new_password": "new123",
"repeat_new_password": "new123",
}
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
):
result = method(api)
assert result["id"] == "u1"
def test_password_wrong_current(self, app):
api = AccountPasswordApi()
method = unwrap(api.post)
payload = {
"password": "bad",
"new_password": "new123",
"repeat_new_password": "new123",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.update_account_password",
side_effect=ServicePwdError(),
),
):
with pytest.raises(CurrentPasswordIncorrectError):
method(api)
class TestAccountIntegrateApi:
def test_get_integrates(self, app):
api = AccountIntegrateApi()
method = unwrap(api.get)
account = MagicMock(id="acc1")
with (
app.test_request_context("/"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
):
scalars_mock.return_value.all.return_value = []
result = method(api)
assert "data" in result
assert len(result["data"]) == 2
class TestAccountDeleteApi:
def test_delete_verify_success(self, app):
api = AccountDeleteVerifyApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
return_value=("token", "1234"),
),
patch(
"controllers.console.workspace.account.AccountService.send_account_deletion_verification_email",
return_value=None,
),
):
result = method(api)
assert result["result"] == "success"
def test_delete_invalid_code(self, app):
api = AccountDeleteApi()
method = unwrap(api.post)
payload = {"token": "t", "code": "x"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
return_value=False,
),
):
with pytest.raises(InvalidAccountDeletionCodeError):
method(api)
class TestChangeEmailApis:
def test_check_email_code_invalid(self, app):
api = ChangeEmailCheckApi()
method = unwrap(api.post)
payload = {"email": "a@test.com", "code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.account.AccountService.get_change_email_data",
return_value={"email": "a@test.com", "code": "y"},
),
):
with pytest.raises(EmailCodeError):
method(api)
def test_reset_email_already_used(self, app):
api = ChangeEmailResetApi()
method = unwrap(api.post)
payload = {"new_email": "x@test.com", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
):
with pytest.raises(EmailAlreadyInUseError):
method(api)
class TestCheckEmailUniqueApi:
def test_email_unique_success(self, app):
api = CheckEmailUnique()
method = unwrap(api.post)
payload = {"email": "ok@test.com"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True),
):
result = method(api)
assert result["result"] == "success"
def test_email_in_freeze(self, app):
api = CheckEmailUnique()
method = unwrap(api.post)
payload = {"email": "x@test.com"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True),
):
with pytest.raises(AccountInFreezeError):
method(api)

View File

@ -0,0 +1,139 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console.error import AccountNotFound
from controllers.console.workspace.agent_providers import (
AgentProviderApi,
AgentProviderListApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAgentProviderListApi:
def test_get_success(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
providers = [{"name": "openai"}, {"name": "anthropic"}]
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=providers,
),
):
result = method(api)
assert result == providers
def test_get_empty_list(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=[],
),
):
result = method(api)
assert result == []
def test_get_account_not_found(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api)
class TestAgentProviderApi:
def test_get_success(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
provider_name = "openai"
provider_data = {"name": "openai", "models": ["gpt-4"]}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=provider_data,
),
):
result = method(api, provider_name)
assert result == provider_data
def test_get_provider_not_found(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
provider_name = "unknown"
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=None,
),
):
result = method(api, provider_name)
assert result is None
def test_get_account_not_found(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api, "openai")

View File

@ -0,0 +1,305 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console.workspace.endpoint import (
EndpointCreateApi,
EndpointDeleteApi,
EndpointDisableApi,
EndpointEnableApi,
EndpointListApi,
EndpointListForSinglePluginApi,
EndpointUpdateApi,
)
from core.plugin.impl.exc import PluginPermissionDeniedError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def user_and_tenant():
return MagicMock(id="u1"), "t1"
@pytest.fixture
def patch_current_account(user_and_tenant):
with patch(
"controllers.console.workspace.endpoint.current_account_with_tenant",
return_value=user_and_tenant,
):
yield
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCreateApi:
def test_create_success(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {"a": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_create_permission_denied(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.endpoint.EndpointService.create_endpoint",
side_effect=PluginPermissionDeniedError("denied"),
),
):
with pytest.raises(ValueError):
method(api)
def test_create_validation_error(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "p1",
"name": "",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app):
api = EndpointListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
):
result = method(api)
assert "endpoints" in result
assert len(result["endpoints"]) == 1
def test_list_invalid_query(self, app):
api = EndpointListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=0&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListForSinglePluginApi:
def test_list_for_plugin_success(self, app):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
patch(
"controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin",
return_value=[{"id": "e1"}],
),
):
result = method(api)
assert "endpoints" in result
def test_list_for_plugin_missing_param(self, app):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDeleteApi:
def test_delete_success(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_delete_invalid_payload(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
def test_delete_service_failure(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointUpdateApi:
def test_update_success(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {
"endpoint_id": "e1",
"name": "new-name",
"settings": {"x": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_update_validation_error(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1", "settings": {}}
with (
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
def test_update_service_failure(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {
"endpoint_id": "e1",
"name": "n",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_enable_invalid_payload(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
def test_enable_service_failure(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDisableApi:
def test_disable_success(self, app):
api = EndpointDisableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_disable_invalid_payload(self, app):
api = EndpointDisableApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)

View File

@ -0,0 +1,607 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import HTTPException
import services
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
MemberNotInTenantError,
NotOwnerError,
OwnerTransferLimitError,
)
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.workspace.members import (
DatasetOperatorMemberListApi,
MemberCancelInviteApi,
MemberInviteEmailApi,
MemberListApi,
MemberUpdateRoleApi,
OwnerTransfer,
OwnerTransferCheckApi,
SendOwnerTransferEmailApi,
)
from services.errors.account import AccountAlreadyInTenantError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestMemberListApi:
def test_get_success(self, app):
api = MemberListApi()
method = unwrap(api.get)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
member.id = "m1"
member.name = "Member"
member.email = "member@test.com"
member.avatar = "avatar.png"
member.role = "admin"
member.status = "active"
members = [member]
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
):
result, status = method(api)
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
api = MemberListApi()
method = unwrap(api.get)
user = MagicMock(current_tenant=None)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
class TestMemberInviteEmailApi:
def test_invite_success(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
"language": "en-US",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, status = method(api)
assert status == 201
assert result["result"] == "success"
def test_invite_limit_exceeded(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = False
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
):
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
def test_invite_already_member(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=AccountAlreadyInTenantError(),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, status = method(api)
assert result["invitation_results"][0]["status"] == "success"
def test_invite_invalid_role(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
payload = {
"emails": ["a@test.com"],
"role": "owner",
}
with app.test_request_context("/", json=payload):
result, status = method(api)
assert status == 400
assert result["code"] == "invalid-role"
def test_invite_generic_exception(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=Exception("boom"),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, _ = method(api)
assert result["invitation_results"][0]["status"] == "failed"
class TestMemberCancelInviteApi:
def test_cancel_success(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 400
def test_cancel_no_permission(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 403
def test_cancel_member_not_in_tenant(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 404
class TestMemberUpdateRoleApi:
def test_update_success(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.update_member_role"),
):
result = method(api, "id")
if isinstance(result, tuple):
result = result[0]
assert result["result"] == "success"
def test_update_invalid_role(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "invalid-role"}
with app.test_request_context("/", json=payload):
result, status = method(api, "id")
assert status == 400
def test_update_member_not_found(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.members.current_account_with_tenant",
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
),
patch("controllers.console.workspace.members.db.session.get", return_value=None),
):
with pytest.raises(HTTPException):
method(api, "id")
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
member.id = "op1"
member.name = "Operator"
member.email = "operator@test.com"
member.avatar = "avatar.png"
member.role = "operator"
member.status = "active"
members = [member]
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
),
):
result, status = method(api)
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
user = MagicMock(current_tenant=None)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
class TestSendOwnerTransferEmailApi:
def test_send_success(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
tenant = MagicMock(name="ws")
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
),
):
result = method(api)
assert result["result"] == "success"
def test_send_ip_limit(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
):
with pytest.raises(EmailSendIpLimitError):
method(api)
def test_send_not_owner(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/", json={}),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
):
with pytest.raises(NotOwnerError):
method(api)
class TestOwnerTransferCheckApi:
def test_check_invalid_code(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com", "code": "y"},
),
):
with pytest.raises(EmailCodeError):
method(api)
def test_rate_limited(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=True,
),
):
with pytest.raises(OwnerTransferLimitError):
method(api)
def test_invalid_token(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api)
def test_invalid_email(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "b@test.com", "code": "x"},
),
):
with pytest.raises(InvalidEmailError):
method(api)
class TestOwnerTransferApi:
def test_transfer_self(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
):
with pytest.raises(CannotTransferOwnerToSelfError):
method(api, "1")
def test_invalid_token(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
member = MagicMock()
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com"},
),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
):
with pytest.raises(MemberNotInTenantError):
method(api, "2")

View File

@ -0,0 +1,388 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic_core import ValidationError
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.model_providers import (
ModelProviderCredentialApi,
ModelProviderCredentialSwitchApi,
ModelProviderIconApi,
ModelProviderListApi,
ModelProviderPaymentCheckoutUrlApi,
ModelProviderValidateApi,
PreferredProviderTypeUpdateApi,
)
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
INVALID_UUID = "123"
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestModelProviderListApi:
def test_get_success(self, app):
api = ModelProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?model_type=llm"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_list",
return_value=[{"name": "openai"}],
),
):
result = method(api)
assert "data" in result
class TestModelProviderCredentialApi:
def test_get_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(f"/?credential_id={VALID_UUID}"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential",
return_value={"key": "value"},
),
):
result = method(api, provider="openai")
assert "credentials" in result
def test_get_invalid_uuid(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(f"/?credential_id={INVALID_UUID}"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_post_create_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}, "name": "test"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
return_value=None,
),
):
result, status = method(api, provider="openai")
assert result["result"] == "success"
assert status == 201
def test_post_create_validation_error(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
side_effect=CredentialsValidateFailedError("bad"),
),
):
with pytest.raises(ValueError):
method(api, provider="openai")
def test_put_update_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_put_invalid_uuid(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_delete_success(self, app):
api = ModelProviderCredentialApi()
method = unwrap(api.delete)
payload = {"credential_id": VALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential",
return_value=None,
),
):
result, status = method(api, provider="openai")
assert result["result"] == "success"
assert status == 204
class TestModelProviderCredentialSwitchApi:
def test_switch_success(self, app):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
payload = {"credential_id": VALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_switch_invalid_uuid(self, app):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
payload = {"credential_id": INVALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
class TestModelProviderValidateApi:
def test_validate_success(self, app):
api = ModelProviderValidateApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_validate_failure(self, app):
api = ModelProviderValidateApi()
method = unwrap(api.post)
payload = {"credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
side_effect=CredentialsValidateFailedError("bad"),
),
):
result = method(api, provider="openai")
assert result["result"] == "error"
class TestModelProviderIconApi:
def test_icon_success(self, app):
api = ModelProviderIconApi()
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
return_value=(b"123", "image/png"),
),
):
response = api.get("t1", "openai", "logo", "en")
assert response.mimetype == "image/png"
def test_icon_not_found(self, app):
api = ModelProviderIconApi()
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
return_value=(None, None),
),
):
with pytest.raises(ValueError):
api.get("t1", "openai", "logo", "en")
class TestPreferredProviderTypeUpdateApi:
def test_update_success(self, app):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
payload = {"preferred_provider_type": "custom"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider",
return_value=None,
),
):
result = method(api, provider="openai")
assert result["result"] == "success"
def test_invalid_enum(self, app):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
payload = {"preferred_provider_type": "invalid"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with pytest.raises(ValidationError):
method(api, provider="openai")
class TestModelProviderPaymentCheckoutUrlApi:
def test_checkout_success(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
user = MagicMock(id="u1", email="x@test.com")
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(user, "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
return_value=None,
),
patch(
"controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link",
return_value={"url": "x"},
),
):
result = method(api, provider="anthropic")
assert "url" in result
def test_invalid_provider(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(ValueError):
method(api, provider="openai")
def test_permission_denied(self, app):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
user = MagicMock(id="u1", email="x@test.com")
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(user, "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
side_effect=Forbidden(),
),
):
with pytest.raises(Forbidden):
method(api, provider="anthropic")

View File

@ -0,0 +1,447 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.workspace.models import (
DefaultModelApi,
ModelProviderAvailableModelApi,
ModelProviderModelApi,
ModelProviderModelCredentialApi,
ModelProviderModelCredentialSwitchApi,
ModelProviderModelDisableApi,
ModelProviderModelEnableApi,
ModelProviderModelParameterRuleApi,
ModelProviderModelValidateApi,
)
from dify_graph.model_runtime.entities.model_entities import ModelType
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDefaultModelApi:
def test_get_success(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.get)
with (
app.test_request_context(
"/",
query_string={"model_type": ModelType.LLM.value},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"}
result = method(api)
assert "data" in result
def test_post_success(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.post)
payload = {
"model_settings": [
{
"model_type": ModelType.LLM.value,
"provider": "openai",
"model": "gpt-4",
}
]
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api)
assert result["result"] == "success"
def test_get_returns_empty_when_no_default(self, app):
api = DefaultModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_default_model_of_model_type.return_value = None
result = method(api)
assert "data" in result
class TestModelProviderModelApi:
def test_get_models_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_models_by_provider.return_value = []
result = method(api, "openai")
assert "data" in result
def test_post_models_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"load_balancing": {
"configs": [{"weight": 1}],
"enabled": True,
},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
patch("controllers.console.workspace.models.ModelLoadBalancingService"),
):
result, status = method(api, "openai")
assert status == 200
def test_delete_model_success(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.delete)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 204
def test_get_models_returns_empty(self, app):
api = ModelProviderModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_models_by_provider.return_value = []
result = method(api, "openai")
assert "data" in result
class TestModelProviderModelCredentialApi:
def test_get_credentials_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(
"/",
query_string={
"model": "gpt-4",
"model_type": ModelType.LLM.value,
},
),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as provider_service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service,
):
provider_service.return_value.get_model_credential.return_value = {
"credentials": {},
"current_credential_id": None,
"current_credential_name": None,
}
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "openai")
assert "credentials" in result
def test_create_credential_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credentials": {"key": "val"},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 201
def test_get_empty_credentials(self, app):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
):
service.return_value.get_model_credential.return_value = None
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "openai")
assert result["credentials"] == {}
def test_delete_success(self, app):
api = ModelProviderModelCredentialApi()
method = unwrap(api.delete)
payload = {
"model": "gpt",
"model_type": ModelType.LLM.value,
"credential_id": "123e4567-e89b-12d3-a456-426614174000",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result, status = method(api, "openai")
assert status == 204
class TestModelProviderModelCredentialSwitchApi:
def test_switch_success(self, app: Flask):
api = ModelProviderModelCredentialSwitchApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credential_id": "abc",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
class TestModelEnableDisableApis:
def test_enable_model(self, app: Flask):
api = ModelProviderModelEnableApi()
method = unwrap(api.patch)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
def test_disable_model(self, app: Flask):
api = ModelProviderModelDisableApi()
method = unwrap(api.patch)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
class TestModelProviderModelValidateApi:
def test_validate_success(self, app: Flask):
api = ModelProviderModelValidateApi()
method = unwrap(api.post)
payload = {
"model": "gpt-4",
"model_type": ModelType.LLM.value,
"credentials": {"key": "val"},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService"),
):
result = method(api, "openai")
assert result["result"] == "success"
@pytest.mark.parametrize("model_name", ["gpt-4", "gpt"])
def test_validate_failure(self, app: Flask, model_name: str):
api = ModelProviderModelValidateApi()
method = unwrap(api.post)
payload = {
"model": model_name,
"model_type": ModelType.LLM.value,
"credentials": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid")
result = method(api, "openai")
assert result["result"] == "error"
class TestParameterAndAvailableModels:
def test_parameter_rules(self, app: Flask):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt-4"}),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_model_parameter_rules.return_value = []
result = method(api, "openai")
assert "data" in result
def test_available_models(self, app: Flask):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.models.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
):
service_mock.return_value.get_models_by_model_type.return_value = []
result = method(api, ModelType.LLM.value)
assert "data" in result
def test_empty_rules(self, app):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
with (
app.test_request_context("/", query_string={"model": "gpt"}),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_model_parameter_rules.return_value = []
result = method(api, "openai")
assert result["data"] == []
def test_no_models(self, app):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
patch("controllers.console.workspace.models.ModelProviderService") as service,
):
service.return_value.get_models_by_model_type.return_value = []
result = method(api, ModelType.LLM.value)
assert result["data"] == []

File diff suppressed because it is too large Load Diff

View File

@ -4,16 +4,52 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_restx import Api
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
from controllers.console.workspace.tool_providers import (
ToolApiListApi,
ToolApiProviderAddApi,
ToolApiProviderDeleteApi,
ToolApiProviderGetApi,
ToolApiProviderGetRemoteSchemaApi,
ToolApiProviderListToolsApi,
ToolApiProviderUpdateApi,
ToolBuiltinListApi,
ToolBuiltinProviderAddApi,
ToolBuiltinProviderCredentialsSchemaApi,
ToolBuiltinProviderDeleteApi,
ToolBuiltinProviderGetCredentialInfoApi,
ToolBuiltinProviderGetCredentialsApi,
ToolBuiltinProviderGetOauthClientSchemaApi,
ToolBuiltinProviderIconApi,
ToolBuiltinProviderInfoApi,
ToolBuiltinProviderListToolsApi,
ToolBuiltinProviderSetDefaultApi,
ToolBuiltinProviderUpdateApi,
ToolLabelsApi,
ToolOAuthCallback,
ToolOAuthCustomClient,
ToolPluginOAuthApi,
ToolProviderListApi,
ToolProviderMCPApi,
ToolWorkflowListApi,
ToolWorkflowProviderCreateApi,
ToolWorkflowProviderDeleteApi,
ToolWorkflowProviderGetApi,
ToolWorkflowProviderUpdateApi,
is_valid_url,
)
from core.db.session_factory import configure_session_factory
from extensions.ext_database import db
from services.tools.mcp_tools_manage_service import ReconnectResult
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
# They are intentionally no-ops because the test already patches the required
# behaviors explicitly via @patch and context managers below.
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def _mock_cache():
return
@ -107,3 +143,602 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
# 若 transform 后包含 tools 字段,确保非空
assert isinstance(body.get("tools"), list)
assert body["tools"]
class TestUtils:
def test_is_valid_url(self):
assert is_valid_url("https://example.com")
assert is_valid_url("http://example.com")
assert not is_valid_url("")
assert not is_valid_url("ftp://example.com")
assert not is_valid_url("not-a-url")
assert not is_valid_url(None)
class TestToolProviderListApi:
def test_get_success(self, app):
api = ToolProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "t1"),
),
patch(
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
return_value=["p1"],
),
):
assert method(api) == ["p1"]
class TestBuiltinProviderApis:
def test_list_tools(self, app):
api = ToolBuiltinProviderListToolsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
return_value=[{"a": 1}],
),
):
assert method(api, "provider") == [{"a": 1}]
def test_info(self, app):
api = ToolBuiltinProviderInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
return_value={"x": 1},
),
):
assert method(api, "provider") == {"x": 1}
def test_delete(self, app):
api = ToolBuiltinProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credential_id": "cid"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider",
return_value={"result": "success"},
),
):
assert method(api, "provider")["result"] == "success"
def test_add_invalid_type(self, app):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}, "type": "invalid"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
):
with pytest.raises(ValueError):
method(api, "provider")
def test_add_success(self, app):
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
payload = {"credentials": {}, "type": "oauth2", "name": "n"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
return_value={"id": 1},
),
):
assert method(api, "provider")["id"] == 1
def test_update(self, app):
api = ToolBuiltinProviderUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "c1", "credentials": {}, "name": "n"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_credentials(self, app):
api = ToolBuiltinProviderGetCredentialsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
return_value={"k": "v"},
),
):
assert method(api, "provider") == {"k": "v"}
def test_icon(self, app):
api = ToolBuiltinProviderIconApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon",
return_value=(b"x", "image/png"),
),
):
response = method(api, "provider")
assert response.mimetype == "image/png"
def test_credentials_schema(self, app):
api = ToolBuiltinProviderCredentialsSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
return_value={"schema": {}},
),
):
assert method(api, "provider", "oauth2") == {"schema": {}}
def test_set_default_credential(self, app):
api = ToolBuiltinProviderSetDefaultApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"id": "c1"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_credential_info(self, app):
api = ToolBuiltinProviderGetCredentialInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
return_value={"info": "x"},
),
):
assert method(api, "provider") == {"info": "x"}
def test_get_oauth_client_schema(self, app):
api = ToolBuiltinProviderGetOauthClientSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
return_value={"schema": {}},
),
):
assert method(api, "provider") == {"schema": {}}
class TestApiProviderApis:
def test_add(self, app):
api = ToolApiProviderAddApi()
method = unwrap(api.post)
payload = {
"credentials": {},
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"icon": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
return_value={"id": 1},
),
):
assert method(api)["id"] == 1
def test_remote_schema(self, app):
api = ToolApiProviderGetRemoteSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/?url=http://x.com"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
return_value={"schema": "x"},
),
):
assert method(api)["schema"] == "x"
def test_list_tools(self, app):
api = ToolApiProviderListToolsApi()
method = unwrap(api.get)
with (
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
return_value=[{"tool": 1}],
),
):
assert method(api) == [{"tool": 1}]
def test_update(self, app):
api = ToolApiProviderUpdateApi()
method = unwrap(api.post)
payload = {
"credentials": {},
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"original_provider": "o",
"icon": {},
"privacy_policy": "",
"custom_disclaimer": "",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
return_value={"ok": True},
),
):
assert method(api)["ok"]
def test_delete(self, app):
api = ToolApiProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"provider": "p"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider",
return_value={"result": "success"},
),
):
assert method(api)["result"] == "success"
def test_get(self, app):
api = ToolApiProviderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
return_value={"x": 1},
),
):
assert method(api) == {"x": 1}
class TestWorkflowApis:
def test_create(self, app):
api = ToolWorkflowProviderCreateApi()
method = unwrap(api.post)
payload = {
"workflow_app_id": "123e4567-e89b-12d3-a456-426614174000",
"name": "n",
"label": "l",
"description": "d",
"icon": {},
"parameters": [],
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
return_value={"id": 1},
),
):
assert method(api)["id"] == 1
def test_update_invalid(self, app):
api = ToolWorkflowProviderUpdateApi()
method = unwrap(api.post)
payload = {
"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000",
"name": "Tool",
"label": "Tool Label",
"description": "A tool",
"icon": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
return_value={"ok": True},
),
):
result = method(api)
assert result["ok"]
def test_delete(self, app):
api = ToolWorkflowProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
return_value={"ok": True},
),
):
assert method(api)["ok"]
def test_get_error(self, app):
api = ToolWorkflowProviderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
):
with pytest.raises(ValueError):
method(api)
class TestLists:
def test_builtin_list(self, app):
api = ToolBuiltinListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
def test_api_list(self, app):
api = ToolApiListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
def test_workflow_list(self, app):
api = ToolWorkflowListApi()
method = unwrap(api.get)
m = MagicMock()
m.to_dict.return_value = {"x": 1}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
class TestLabels:
def test_labels(self, app):
api = ToolLabelsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels",
return_value=["l1"],
),
):
assert method(api) == ["l1"]
class TestOAuth:
def test_oauth_no_client(self, app):
api = ToolPluginOAuthApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "provider")
def test_oauth_callback_no_cookie(self, app):
api = ToolOAuthCallback()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "provider")
class TestOAuthCustomClient:
def test_save_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"client_params": {"a": 1}}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
def test_get_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params",
return_value={"client_id": "x"},
),
):
assert method(api, "provider") == {"client_id": "x"}
def test_delete_custom_client(self, app):
api = ToolOAuthCustomClient()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]

View File

@ -0,0 +1,558 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden
from controllers.console.workspace.trigger_providers import (
TriggerOAuthAuthorizeApi,
TriggerOAuthCallbackApi,
TriggerOAuthClientManageApi,
TriggerProviderIconApi,
TriggerProviderInfoApi,
TriggerProviderListApi,
TriggerSubscriptionBuilderBuildApi,
TriggerSubscriptionBuilderCreateApi,
TriggerSubscriptionBuilderGetApi,
TriggerSubscriptionBuilderLogsApi,
TriggerSubscriptionBuilderUpdateApi,
TriggerSubscriptionBuilderVerifyApi,
TriggerSubscriptionDeleteApi,
TriggerSubscriptionListApi,
TriggerSubscriptionUpdateApi,
TriggerSubscriptionVerifyApi,
)
from controllers.web.error import NotFoundError
from core.plugin.entities.plugin_daemon import CredentialType
from models.account import Account
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def mock_user():
user = MagicMock(spec=Account)
user.id = "u1"
user.current_tenant_id = "t1"
return user
class TestTriggerProviderApis:
def test_icon_success(self, app):
api = TriggerProviderIconApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon",
return_value="icon",
),
):
assert method(api, "github") == "icon"
def test_list_providers(self, app):
api = TriggerProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers",
return_value=[],
),
):
assert method(api) == []
def test_provider_info(self, app):
api = TriggerProviderInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
return_value={"id": "p1"},
),
):
assert method(api, "github") == {"id": "p1"}
class TestTriggerSubscriptionListApi:
def test_list_success(self, app):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
return_value=[],
),
):
assert method(api, "github") == []
def test_list_invalid_provider(self, app):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
side_effect=ValueError("bad"),
),
):
result, status = method(api, "bad")
assert status == 404
class TestTriggerSubscriptionBuilderApis:
def test_create_builder(self, app):
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value={"id": "b1"},
),
):
result = method(api, "github")
assert "subscription_builder" in result
def test_get_builder(self, app):
api = TriggerSubscriptionBuilderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
return_value={"id": "b1"},
),
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_verify_builder(self, app):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {"a": 1}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
return_value={"ok": True},
),
):
assert method(api, "github", "b1") == {"ok": True}
def test_verify_builder_error(self, app):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
side_effect=Exception("err"),
),
):
with pytest.raises(ValueError):
method(api, "github", "b1")
def test_update_builder(self, app):
api = TriggerSubscriptionBuilderUpdateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "n"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
return_value={"id": "b1"},
),
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_logs(self, app):
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
log = MagicMock()
log.model_dump.return_value = {"a": 1}
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
return_value=[log],
),
):
assert "logs" in method(api, "github", "b1")
def test_build(self, app):
api = TriggerSubscriptionBuilderBuildApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder",
return_value=None,
),
):
assert method(api, "github", "b1") == 200
class TestTriggerSubscriptionCrud:
def test_update_rename_only(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
sub = MagicMock()
sub.provider_id = "github"
sub.credential_type = CredentialType.UNAUTHORIZED
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=sub,
),
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
):
assert method(api, "s1") == 200
def test_update_not_found(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"name": "x"}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=None,
),
):
with pytest.raises(NotFoundError):
method(api, "x")
def test_update_rebuild(self, app):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
sub = MagicMock()
sub.provider_id = "github"
sub.credential_type = CredentialType.OAUTH2
sub.credentials = {}
sub.parameters = {}
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
return_value=sub,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
),
):
assert method(api, "s1") == 200
def test_delete_subscription(self, app):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
mock_session = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls,
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription"
),
):
mock_db.engine = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
result = method(api, "sub1")
assert result["result"] == "success"
def test_delete_subscription_value_error(self, app):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as session_cls,
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider",
side_effect=ValueError("bad"),
),
):
mock_db.engine = MagicMock()
session_cls.return_value.__enter__.return_value = MagicMock()
with pytest.raises(BadRequest):
method(api, "sub1")
class TestTriggerOAuthApis:
def test_oauth_authorize_success(self, app):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
return_value=MagicMock(id="b1"),
),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context",
return_value="ctx",
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url",
return_value=MagicMock(authorization_url="url"),
),
):
resp = method(api, "github")
assert resp.status_code == 200
def test_oauth_authorize_no_client(self, app):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(NotFoundError):
method(api, "github")
def test_oauth_callback_forbidden(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_success(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
return_value=MagicMock(credentials={"a": 1}, expires_at=1),
),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder"
),
):
resp = method(api, "github")
assert resp.status_code == 302
def test_oauth_callback_no_oauth_client(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
return_value=ctx,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_empty_credentials(self, app):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
ctx = {
"user_id": "u1",
"tenant_id": "t1",
"subscription_builder_id": "b1",
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch(
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
return_value=ctx,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
return_value={"a": 1},
),
patch(
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
return_value=MagicMock(credentials=None, expires_at=None),
),
):
with pytest.raises(ValueError):
method(api, "github")
class TestTriggerOAuthClientManageApi:
def test_get_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params",
return_value={},
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled",
return_value=False,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists",
return_value=True,
),
patch(
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider",
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
),
):
result = method(api, "github")
assert "configured" in result
def test_post_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"enabled": True}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "github") == {"ok": True}
def test_delete_client(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "github") == {"ok": True}
def test_oauth_client_post_value_error(self, app):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"enabled": True}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
side_effect=ValueError("bad"),
),
):
with pytest.raises(BadRequest):
method(api, "github")
class TestTriggerSubscriptionVerifyApi:
def test_verify_success(self, app):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
return_value={"ok": True},
),
):
assert method(api, "github", "s1") == {"ok": True}
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
def test_verify_errors(self, app, raised_exception):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}}),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
side_effect=raised_exception,
),
):
with pytest.raises(BadRequest):
method(api, "github", "s1")

View File

@ -0,0 +1,605 @@
from datetime import datetime
from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Unauthorized
import services
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.workspace.workspace import (
CustomConfigWorkspaceApi,
SwitchWorkspaceApi,
TenantApi,
TenantListApi,
WebappLogoWorkspaceApi,
WorkspaceInfoApi,
WorkspaceListApi,
WorkspacePermissionApi,
)
from enums.cloud_plan import CloudPlan
from models.account import TenantStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestTenantListApi:
def test_get_success(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
features = MagicMock()
features.billing.enabled = True
features.billing.subscription.plan = CloudPlan.SANDBOX
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
):
result, status = method(api)
assert status == 200
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
def test_get_billing_disabled(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant = MagicMock(
id="t1",
name="Tenant",
status="active",
created_at=datetime.utcnow(),
)
features = MagicMock()
features.billing.enabled = False
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant],
),
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
),
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
class TestWorkspaceListApi:
def test_get_success(self, app):
api = WorkspaceListApi()
method = unwrap(api.get)
tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow())
paginate_result = MagicMock(
items=[tenant],
has_next=False,
total=1,
)
with (
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}),
patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result),
):
result, status = method(api)
assert status == 200
assert result["total"] == 1
assert result["has_more"] is False
def test_get_has_next_true(self, app):
api = WorkspaceListApi()
method = unwrap(api.get)
tenant = MagicMock(
id="t1",
name="T",
status="active",
created_at=datetime.utcnow(),
)
paginate_result = MagicMock(
items=[tenant],
has_next=True,
total=10,
)
with (
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}),
patch(
"controllers.console.workspace.workspace.db.paginate",
return_value=paginate_result,
),
):
result, status = method(api)
assert status == 200
assert result["has_more"] is True
class TestTenantApi:
def test_post_active_tenant(self, app):
api = TenantApi()
method = unwrap(api.post)
tenant = MagicMock(status="active")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
),
):
result, status = method(api)
assert status == 200
assert result["id"] == "t1"
def test_post_archived_with_switch(self, app):
api = TenantApi()
method = unwrap(api.post)
archived = MagicMock(status=TenantStatus.ARCHIVE)
new_tenant = MagicMock(status="active")
user = MagicMock(current_tenant=archived)
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"}
),
):
result, status = method(api)
assert result["id"] == "new"
def test_post_archived_no_tenant(self, app):
api = TenantApi()
method = unwrap(api.post)
user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE))
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]),
):
with pytest.raises(Unauthorized):
method(api)
def test_post_info_path(self, app):
api = TenantApi()
method = unwrap(api.post)
tenant = MagicMock(status="active")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/info"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(user, "t1"),
),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"id": "t1"},
),
patch("controllers.console.workspace.workspace.logger.warning") as warn_mock,
):
result, status = method(api)
warn_mock.assert_called_once()
assert status == 200
class TestSwitchWorkspaceApi:
def test_switch_success(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "t2"}
tenant = MagicMock(id="t2")
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
),
):
query_mock.return_value.get.return_value = tenant
result = method(api)
assert result["result"] == "success"
def test_switch_not_linked(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "bad"}
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception),
):
with pytest.raises(AccountNotLinkTenantError):
method(api)
def test_switch_tenant_not_found(self, app):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
payload = {"tenant_id": "missing"}
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
):
query_mock.return_value.get.return_value = None
with pytest.raises(ValueError):
method(api)
class TestCustomConfigWorkspaceApi:
def test_post_success(self, app):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
tenant = MagicMock(custom_config_dict={})
payload = {"remove_webapp_brand": True}
with (
app.test_request_context("/workspaces/custom-config", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
),
):
result = method(api)
assert result["result"] == "success"
def test_logo_fallback(self, app):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"})
payload = {"remove_webapp_brand": False}
with (
app.test_request_context("/workspaces/custom-config", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch(
"controllers.console.workspace.workspace.db.get_or_404",
return_value=tenant,
),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"id": "t1"},
),
):
result = method(api)
assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo"
assert result["result"] == "success"
class TestWebappLogoWorkspaceApi:
def test_no_file(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
with (
app.test_request_context("/upload", data={}),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
):
with pytest.raises(NoFileUploadedError):
method(api)
def test_too_many_files(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
data = {
"file": MagicMock(),
"extra": MagicMock(),
}
with (
app.test_request_context("/upload", data=data),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
):
with pytest.raises(TooManyFilesError):
method(api)
def test_invalid_extension(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = MagicMock(filename="test.txt")
with (
app.test_request_context("/upload", data={"file": file}),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
):
with pytest.raises(UnsupportedFileTypeError):
method(api)
def test_upload_success(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"data"),
filename="logo.png",
content_type="image/png",
)
upload = MagicMock(id="file1")
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.return_value = upload
result, status = method(api)
assert status == 201
assert result["id"] == "file1"
def test_filename_missing(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"data"),
filename="",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
):
with pytest.raises(FilenameNotExistsError):
method(api)
def test_file_too_large(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"x"),
filename="logo.png",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big")
with pytest.raises(FileTooLargeError):
method(api)
def test_service_unsupported_file(self, app):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"x"),
filename="logo.png",
content_type="image/png",
)
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
with pytest.raises(UnsupportedFileTypeError):
method(api)
class TestWorkspaceInfoApi:
def test_post_success(self, app):
api = WorkspaceInfoApi()
method = unwrap(api.post)
tenant = MagicMock()
payload = {"name": "New Name"}
with (
app.test_request_context("/workspaces/info", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"name": "New Name"},
),
):
result = method(api)
assert result["result"] == "success"
def test_no_current_tenant(self, app):
api = WorkspaceInfoApi()
method = unwrap(api.post)
payload = {"name": "X"}
with (
app.test_request_context("/workspaces/info", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(ValueError):
method(api)
class TestWorkspacePermissionApi:
def test_get_success(self, app):
api = WorkspacePermissionApi()
method = unwrap(api.get)
permission = MagicMock(
workspace_id="t1",
allow_member_invite=True,
allow_owner_transfer=False,
)
with (
app.test_request_context("/permission"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission",
return_value=permission,
),
):
result, status = method(api)
assert status == 200
assert result["workspace_id"] == "t1"
def test_no_current_tenant(self, app):
api = WorkspacePermissionApi()
method = unwrap(api.get)
with (
app.test_request_context("/permission"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(ValueError):
method(api)

View File

@ -0,0 +1,142 @@
from __future__ import annotations
import importlib
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.workspace import plugin_permission_required
from models.account import TenantPluginPermission
class _SessionStub:
def __init__(self, permission):
self._permission = permission
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def query(self, *_args, **_kwargs):
return self
def where(self, *_args, **_kwargs):
return self
def first(self):
return self._permission
def _workspace_module():
return importlib.import_module(plugin_permission_required.__module__)
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
module = _workspace_module()
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, None)
@plugin_permission_required()
def handler():
return "ok"
assert handler() == "ok"
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(install_required=True)
def handler():
return "ok"
assert handler() == "ok"
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=True)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(is_admin_or_owner=False)
permission = SimpleNamespace(
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
)
module = _workspace_module()
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
_patch_session(monkeypatch, permission)
@plugin_permission_required(debug_required=True)
def handler():
return "ok"
with pytest.raises(Forbidden):
handler()